Softmax implementation
Hi @markus-eberts ,
thanks for sharing your great work.
I was experimenting with a variation of SpERT in which the relations are classified using a softmax instead of a sigmoid.
To check the correctness of the overall system, I trained it on the version of conll04 that you provided with the model, and everything seemed fine.
The issues arose when I tried to train it on a different dataset, converted to a format compatible with SpERT.
Training went smoothly, but the model didn't make any predictions at all, for either entities or relations. I am surely missing something, and I was wondering if you could point me in a direction from which to start.
Here is a single sample from the training dataset:
{"tokens": ["The", "role", "of", "p27(Kip1", ")", "in", "dasatinib-enhanced", "paclitaxel", "cytotoxicity", "in", "human", "ovarian", "cancer", "cells", ".", "\r\n"], "entities": [{"type": "drug", "start": 6, "end": 7}, {"type": "drug", "start": 7, "end": 8}], "relations": [{"type": "effect", "head": 0, "tail": 1}], "orig_id": "DDI-MedLine.d194.s0"}
On this dataset a softmax is recommended, since all the relations are symmetric and at most a single relation exists between any two entities.
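To make the difference between the two heads concrete, here is a minimal, dependency-free sketch. The logits are invented for illustration, and the 0.4 threshold simply mirrors `rel_filter_threshold` from the config below. In the real sigmoid head there is no "none" logit, which is why the sketch only scores the four relation types there.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits):
    # Subtract the maximum for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for ONE candidate entity pair (illustrative values only).
labels = ["none", "effect", "int", "mechanism", "advise"]
logits = [0.2, 2.1, -1.0, 1.9, -0.5]

# Sigmoid head (multi-label): every relation type is scored independently,
# so several types can pass the threshold for the same pair.
sigmoid_scores = [sigmoid(v) for v in logits[1:]]
multi_label = [l for l, s in zip(labels[1:], sigmoid_scores) if s >= 0.4]

# Softmax head (single-label): the scores compete and sum to 1, so exactly
# one class wins, and "none" is an explicit class at index 0.
probs = softmax(logits)
single_label = labels[probs.index(max(probs))]

print(multi_label)   # relation types passing the threshold independently
print(single_label)  # the single winning class
```

With the sigmoid head both "effect" and "mechanism" pass the threshold for the same pair, whereas the softmax head commits to one class.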
Here is the log of the training run:
Config:
{'label': 'softmax_ddi', 'model_type': 'spert', 'model_path': 'bert-base-cased', 'tokenizer_path': 'bert-base-cased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times
Iteration 0
2021-05-29 09:45:30,631 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-29 09:45:30,631 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6045/6045 [00:09<00:00, 622.16it/s]
Parse dataset 'valid': 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 931/931 [00:01<00:00, 609.67it/s]
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relation type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entity type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entities:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Entity=0
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug name=1
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug=2
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Group=3
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Brand=4
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relations:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Relation=0
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Effect=1
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Int=2
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Mechanism=3
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Advise=4
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: train
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 6045
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 3378
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 12549
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: valid
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 931
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 642
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 2216
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates total: 15110[...]
Evaluation
--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly
type         precision    recall  f1-score   support
drug_n            0.00      0.00      0.00     101.0
drug              0.00      0.00      0.00    1396.0
group             0.00      0.00      0.00     538.0
brand             0.00      0.00      0.00     169.0
micro             0.00      0.00      0.00    2204.0
macro             0.00      0.00      0.00    2204.0
--- Relations ---
Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)
type         precision    recall  f1-score   support
advise            0.00      0.00      0.00     130.0
int               0.00      0.00      0.00       8.0
effect            0.00      0.00      0.00     250.0
mechanism         0.00      0.00      0.00     253.0
micro             0.00      0.00      0.00     641.0
macro             0.00      0.00      0.00     641.0
With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)
type         precision    recall  f1-score   support
advise            0.00      0.00      0.00     130.0
int               0.00      0.00      0.00       8.0
effect            0.00      0.00      0.00     250.0
mechanism         0.00      0.00      0.00     253.0
micro             0.00      0.00      0.00     641.0
macro             0.00      0.00      0.00     641.0
2021-05-29 12:21:19,887 [MainThread ] [INFO ] Logged in: data/log/softmax_ddi/2021-05-29_09-45-29.875587
2021-05-29 12:21:19,887 [MainThread ] [INFO ] Saved in: data/save/softmax_ddi/2021-05-29_09-45-29.875587
The following are the major changes that I applied to the original model:
spert/spert_trainer.py
class SpERTTrainer(BaseTrainer):
config=config,
# SpERT model parameters
cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
- relation_types=input_reader.relation_type_count - 1,
+ relation_types=input_reader.relation_type_count,
entity_types=input_reader.entity_type_count,
max_pairs=self._args.max_pairs,
prop_drop=self._args.prop_drop,
class SpERTTrainer(BaseTrainer):
num_warmup_steps=args.lr_warmup * updates_total,
num_training_steps=updates_total)
# create loss function
- rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
+ rel_criterion = torch.nn.CrossEntropyLoss(reduction='none')
entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')
spert/loss.py
class SpERTLoss(Loss):
if rel_count.item() != 0:
rel_logits = rel_logits.view(-1, rel_logits.shape[-1])
- rel_types = rel_types.view(-1, rel_types.shape[-1])
+ rel_types = rel_types.view(-1)
rel_loss = self._rel_criterion(rel_logits, rel_types)
- rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1]
rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count
# joint loss
spert/sampling.py
def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span
rel_sample_masks = torch.zeros([1], dtype=torch.bool)
# relation types to one-hot encoding
- rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
- rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
- rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation
return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
entity_sizes=entity_sizes, entity_types=entity_types,
- rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
+ rels=rels, rel_masks=rel_masks, rel_types=rel_types,
entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)
spert/models.py
class SpERT(BertPreTrainedModel):
chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
relations, rel_masks, h_large, i)
# apply sigmoid
- chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
- rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
+ rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits
- rel_clf = rel_clf * rel_sample_masks # mask
# apply softmax
entity_clf = torch.softmax(entity_clf, dim=2)
+ rel_clf = torch.softmax(rel_clf, dim=2)
+ rel_clf *= rel_sample_masks
return entity_clf, rel_clf, relations
spert/predictions.py
def convert_predictions(batch_entity_clf: torch.tensor, batch_rel_clf: torch.ten
batch_entity_types *= batch['entity_sample_masks'].long()
# apply threshold to relations
- batch_rel_clf[batch_rel_clf < rel_filter_threshold] = 0
batch_pred_entities = []
batch_pred_relations = []
spert/predictions.py
def _convert_pred_relations(rel_clf: torch.tensor, rels: torch.tensor,
entity_types: torch.tensor, entity_spans: torch.tensor, input_reader: BaseInputReader):
- rel_class_count = rel_clf.shape[1]
- rel_clf = rel_clf.view(-1)
# get predicted relation labels and corresponding entity pairs
- rel_nonzero = rel_clf.nonzero().view(-1)
- pred_rel_scores = rel_clf[rel_nonzero]
-
- pred_rel_types = (rel_nonzero % rel_class_count) + 1 # model does not predict None class (+1)
- valid_rel_indices = rel_nonzero // rel_class_count
+ valid_rel_indices = torch.nonzero(torch.sum(rel_clf, dim=-1)).view(-1)
+ valid_rel_indices = valid_rel_indices.view(-1)
+
+ pred_rel_types = rel_clf[valid_rel_indices]
+ if pred_rel_types.shape[0] != 0:
+ pred_rel_types = pred_rel_types.argmax(dim=-1)
+ valid_rel_indices = torch.nonzero(pred_rel_types).view(-1)
+
+ pred_rel_types = pred_rel_types[valid_rel_indices]
+
+ pred_rel_scores = rel_clf[valid_rel_indices]
+ if pred_rel_scores.shape[0] != 0:
+ pred_rel_scores = pred_rel_scores.max(dim=-1)[0]
valid_rels = rels[valid_rel_indices]
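One thing that may be worth double-checking in the decoding change above: the second `torch.nonzero` call returns positions relative to the already filtered rows, yet those positions are then used to index the unfiltered `rel_clf` and `rels`. If any masked (all-zero) rows come before a predicted relation, the two numberings differ. The sketch below shows the intended logic in plain Python, keeping every index in the original row numbering. The function name and the list-based inputs are assumptions made for illustration and are not part of SpERT.

```python
def decode_relations(rel_probs, sample_mask):
    """Pick the argmax class per candidate pair, skipping masked pairs and 'No Relation'.

    rel_probs:   list of per-pair probability rows; index 0 is 'No Relation'.
    sample_mask: list of bools, False for padded / invalid pairs.
    Returns (pair_indices, types, scores), where pair_indices refer to rows
    of the ORIGINAL rel_probs, so they can be used to look up the pairs.
    """
    pair_indices, types, scores = [], [], []
    for idx, (row, keep) in enumerate(zip(rel_probs, sample_mask)):
        if not keep:
            continue  # padded pair, ignore
        best = max(range(len(row)), key=row.__getitem__)  # argmax
        if best == 0:
            continue  # model predicts 'No Relation'
        pair_indices.append(idx)
        types.append(best)
        scores.append(row[best])
    return pair_indices, types, scores


# Illustrative values: row 0 is masked, row 2 predicts 'No Relation'.
rel_probs = [
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.1, 0.7, 0.1, 0.05, 0.05],
    [0.9, 0.025, 0.025, 0.025, 0.025],
    [0.2, 0.1, 0.1, 0.5, 0.1],
]
sample_mask = [False, True, True, True]

pair_indices, types, scores = decode_relations(rel_probs, sample_mask)
print(pair_indices, types, scores)
```

Here the surviving pairs are rows 1 and 3 of the original tensor. Indexing relative to the filtered rows would instead give 0 and 2, which point at the wrong pairs.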
This is not related to the previous topic, but I thought I'd add it here since the same dataset is involved.
While experimenting with the original SpERT I replaced BERT with SciBERT. With 1 epoch of training I had no issues whatsoever, but when I increased the number of epochs to 5, the procedure that stores the predictions started to pick up relations that should have been filtered out by earlier processing (if I interpreted everything correctly).
Here is the log:
Config:
{'label': 'scibert_ddi', 'model_type': 'spert', 'model_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'tokenizer_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times
Iteration 0
2021-05-28 10:54:28,162 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-28 10:54:28,162 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6045/6045 [00:12<00:00, 466.41it/s]
Parse dataset 'valid': 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 931/931 [00:02<00:00, 359.59it/s]
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Relation type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entity type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entities:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Entity=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug name=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Group=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Brand=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relations:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Relation=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Effect=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Int=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Mechanism=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Advise=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: train
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 6045
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relation count: 3378
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Entity count: 12549
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: valid
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 931
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Relation count: 642
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Entity count: 2216
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates total: 15110[...]
Evaluation
--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly
type         precision    recall  f1-score   support
brand             0.00      0.00      0.00     169.0
drug              0.00      0.00      0.00    1396.0
drug_n            0.00      0.00      0.00     101.0
group             0.00      0.00      0.00     538.0
micro             0.00      0.00      0.00    2204.0
macro             0.00      0.00      0.00    2204.0
--- Relations ---
Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)
type         precision    recall  f1-score   support
advise            0.00      0.00      0.00     130.0
mechanism         0.00      0.00      0.00     253.0
effect            0.00      0.00      0.00     250.0
int               0.00      0.00      0.00       8.0
micro             0.00      0.00      0.00     641.0
macro             0.00      0.00      0.00     641.0
With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)
type         precision    recall  f1-score   support
advise            0.00      0.00      0.00     130.0
mechanism         0.00      0.00      0.00     253.0
effect            0.00      0.00      0.00     250.0
int               0.00      0.00      0.00       8.0
micro             0.00      0.00      0.00     641.0
macro             0.00      0.00      0.00     641.0
Process SpawnProcess-1:
Traceback (most recent call last):
File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/deeplearning/Salvalai/spert/spert.py", line 16, in __train
trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path,
File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 97, in train
self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)
File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 253, in _eval
evaluator.store_predictions()
File "/home/deeplearning/Salvalai/spert/spert/evaluator.py", line 87, in store_predictions
prediction.store_predictions(self._dataset.documents, self._pred_entities,
File "/home/deeplearning/Salvalai/spert/spert/prediction.py", line 196, in store_predictions
head_idx = converted_entities.index(converted_head)
ValueError: {'type': 'None', 'start': 0, 'end': 1} is not in list
Best regards
Hi,
I just pushed a fix for a corner case (commit e0d9aee) which may be related to your problem. In some cases (especially strings containing only control characters) the tokenizer we are using maps tokens to empty sequences, which could lead to divisions by zero (and NaN values) down the road. Can you please check if the commit fixes your problem? Even if it does, it would still be better to remove any control characters from your dataset beforehand.
If this does not fix your issue, could you please send me the dataset (or a representative part of it) by email (markus.eberts@hs-rm.de)? I can have a look at it then.
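For the suggested clean-up, a minimal sketch of a preprocessing step could look like the following. It assumes the JSON layout of the sample posted above (token-level `start`/`end` indices with an exclusive `end`). The helper names `is_blank_token` and `clean_document` are made up for this example and are not part of SpERT. Note that the posted sample ends with a `"\r\n"` token, which is exactly the kind of token that would be dropped.

```python
import unicodedata


def is_blank_token(token: str) -> bool:
    """True if the token is empty or has only whitespace/control/format characters."""
    return all(ch.isspace() or unicodedata.category(ch) in ("Cc", "Cf") for ch in token)


def clean_document(doc: dict) -> dict:
    """Drop blank tokens and shift entity spans so they still cover the same words."""
    kept = []
    new_index = []  # new_index[i] = number of kept tokens before old position i
    for tok in doc["tokens"]:
        new_index.append(len(kept))
        if not is_blank_token(tok):
            kept.append(tok)
    new_index.append(len(kept))  # allows an exclusive 'end' equal to len(tokens)

    entities = [
        dict(e, start=new_index[e["start"]], end=new_index[e["end"]])
        for e in doc["entities"]
    ]
    if any(e["start"] >= e["end"] for e in entities):
        raise ValueError("an entity span consists only of blank tokens")
    # Relations refer to entities by list position, so they stay unchanged.
    return dict(doc, tokens=kept, entities=entities)


sample = {
    "tokens": ["The", "role", "of", "p27(Kip1", ")", "in", "dasatinib-enhanced",
               "paclitaxel", "cytotoxicity", "in", "human", "ovarian", "cancer",
               "cells", ".", "\r\n"],
    "entities": [{"type": "drug", "start": 6, "end": 7},
                 {"type": "drug", "start": 7, "end": 8}],
    "relations": [{"type": "effect", "head": 0, "tail": 1}],
    "orig_id": "DDI-MedLine.d194.s0",
}

cleaned = clean_document(sample)
print(len(sample["tokens"]), "->", len(cleaned["tokens"]))
```

Because the blank token sits at the end of this sentence, the entity spans are unchanged here. Blank tokens earlier in a sentence would shift the spans that follow them.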
Thank you, I'll check ASAP and let you know.