lavis-nlp/spert

Softmax implementation


Hi @markus-eberts,
thanks for sharing your great work.

I was experimenting with a variation of SpERT in which relations are classified with a softmax instead of a sigmoid.
To check that the overall system still worked, I trained it on the version of CoNLL04 that you provide with the model, and everything seemed fine.
The issues arose when I trained it on a different dataset, converted to a format compatible with SpERT.
Training went smoothly, but the model made no predictions at all, neither entities nor relations. I am surely missing something; could you point me to where I should start looking?

Here is a single sample from the training dataset:

{"tokens": ["The", "role", "of", "p27(Kip1", ")", "in", "dasatinib-enhanced", "paclitaxel", "cytotoxicity", "in", "human", "ovarian", "cancer", "cells", ".", "\r\n"], "entities": [{"type": "drug", "start": 6, "end": 7}, {"type": "drug", "start": 7, "end": 8}], "relations": [{"type": "effect", "head": 0, "tail": 1}], "orig_id": "DDI-MedLine.d194.s0"}
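Because the model predicted nothing at all on this dataset (not even entities) while it worked on CoNLL04, the converted files are a natural first suspect. Below is a hypothetical consistency check (`check_spert_doc` is not part of SpERT). It assumes the layout of the sample above, with exclusive `end` indices and relation `head`/`tail` indices pointing into `entities`, and takes the type names from the evaluation tables in the log.

```python
def check_spert_doc(doc: dict, entity_types: set, relation_types: set) -> list:
    """Return human-readable problems found in one SpERT-style document.

    Assumed layout (as in the sample): "tokens" is a list of strings, entity
    "end" indices are exclusive, and relation "head"/"tail" index into "entities".
    """
    problems = []
    n_tokens = len(doc["tokens"])
    n_entities = len(doc["entities"])
    for j, ent in enumerate(doc["entities"]):
        if ent["type"] not in entity_types:
            problems.append(f"entity {j}: unknown type {ent['type']!r}")
        if not 0 <= ent["start"] < ent["end"] <= n_tokens:
            problems.append(f"entity {j}: span [{ent['start']}, {ent['end']}) out of range")
    for k, rel in enumerate(doc["relations"]):
        if rel["type"] not in relation_types:
            problems.append(f"relation {k}: unknown type {rel['type']!r}")
        for role in ("head", "tail"):
            if not 0 <= rel[role] < n_entities:
                problems.append(f"relation {k}: {role}={rel[role]} has no matching entity")
    return problems


# Type names as they appear in the evaluation tables (assumed to be the dataset keys)
ENTITY_TYPES = {"drug", "drug_n", "group", "brand"}
RELATION_TYPES = {"effect", "int", "mechanism", "advise"}

sample = {
    "tokens": ["The", "role", "of", "p27(Kip1", ")", "in", "dasatinib-enhanced", "paclitaxel",
               "cytotoxicity", "in", "human", "ovarian", "cancer", "cells", ".", "\r\n"],
    "entities": [{"type": "drug", "start": 6, "end": 7}, {"type": "drug", "start": 7, "end": 8}],
    "relations": [{"type": "effect", "head": 0, "tail": 1}],
}
print(check_spert_doc(sample, ENTITY_TYPES, RELATION_TYPES))  # [] -> structurally consistent
```

Note that the sample passes these structural checks, so its trailing `"\r\n"` token is not caught here; tokens like that need a separate check.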

On this dataset a softmax is the better fit, since all relations are symmetric and at most one relation holds between any two entities.
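The difference between the two output layers shows up on a single entity pair. The plain-Python sketch below uses made-up logits and is not SpERT code. With per-type sigmoids, as in the original model (which has no "No Relation" output), several types can pass the `rel_filter_threshold` of 0.4 at once. A softmax over all types including "No Relation" yields exactly one winning label per pair.

```python
import math

def sigmoid(xs):
    # Independent per-class probabilities (multi-label): each in (0, 1), no sum constraint
    return [1.0 / (1.0 + math.exp(-x)) for x in xs]

def softmax(xs):
    # Mutually exclusive classes (single-label): probabilities sum to 1
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for one entity pair.
# Sigmoid setup: 4 outputs (effect, int, mechanism, advise), no 'No Relation' column
sig_logits = [2.0, 1.5, -3.0, -3.0]
# Softmax setup: 5 outputs, column 0 = 'No Relation'
soft_logits = [0.0, 2.0, 1.5, -3.0, -3.0]

sig = sigmoid(sig_logits)
soft = softmax(soft_logits)
print([round(p, 3) for p in sig])   # effect and int both exceed 0.4
print([round(p, 3) for p in soft])  # a single argmax (effect), probabilities sum to 1
```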

Here is the log of the training run:


Config:
{'label': 'softmax_ddi', 'model_type': 'spert', 'model_path': 'bert-base-cased', 'tokenizer_path': 'bert-base-cased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times

Iteration 0

2021-05-29 09:45:30,631 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-29 09:45:30,631 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6045/6045 [00:09<00:00, 622.16it/s]
Parse dataset 'valid': 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 931/931 [00:01<00:00, 609.67it/s]
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relation type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entity type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entities:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Entity=0
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug name=1
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug=2
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Group=3
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Brand=4
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relations:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Relation=0
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Effect=1
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Int=2
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Mechanism=3
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Advise=4
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: train
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 6045
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 3378
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 12549
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: valid
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 931
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 642
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 2216
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates total: 15110

[...]

Evaluation

--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly

           type    precision       recall     f1-score      support
         drug_n         0.00         0.00         0.00        101.0
           drug         0.00         0.00         0.00       1396.0
          group         0.00         0.00         0.00        538.0
          brand         0.00         0.00         0.00        169.0

          micro         0.00         0.00         0.00       2204.0
          macro         0.00         0.00         0.00       2204.0

--- Relations ---

Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
            int         0.00         0.00         0.00          8.0
         effect         0.00         0.00         0.00        250.0
      mechanism         0.00         0.00         0.00        253.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
            int         0.00         0.00         0.00          8.0
         effect         0.00         0.00         0.00        250.0
      mechanism         0.00         0.00         0.00        253.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

2021-05-29 12:21:19,887 [MainThread ] [INFO ] Logged in: data/log/softmax_ddi/2021-05-29_09-45-29.875587
2021-05-29 12:21:19,887 [MainThread ] [INFO ] Saved in: data/save/softmax_ddi/2021-05-29_09-45-29.875587
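Every score being exactly 0.00 while support is non-zero is what strict matching reports when the model outputs no correct predictions at all. Here is a minimal sketch of strict-match micro scoring. It assumes items are compared as `(start, end, type)` tuples, as the evaluation text describes, and is an illustration rather than SpERT's evaluator.

```python
def micro_prf(gold: set, pred: set):
    """Micro precision/recall/F1 over exact matches.

    Each element is a hashable tuple such as (start, end, type), so an item only
    counts as correct if both span AND type match.
    """
    tp = len(gold & pred)
    precision = tp / len(pred) if pred else 0.0  # undefined without predictions; report 0 here
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

gold = {(6, 7, "drug"), (7, 8, "drug")}
print(micro_prf(gold, set()))                               # model predicts nothing
print(micro_prf(gold, {(6, 7, "drug"), (7, 8, "group")}))   # second span right, type wrong
```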

The following are the major changes that I applied to the original model:
spert/spert_trainer.py

 class SpERTTrainer(BaseTrainer):
                                             config=config,
                                             # SpERT model parameters
                                             cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
-                                            relation_types=input_reader.relation_type_count - 1,
+                                            relation_types=input_reader.relation_type_count,
                                             entity_types=input_reader.entity_type_count,
                                             max_pairs=self._args.max_pairs,
                                             prop_drop=self._args.prop_drop,
 class SpERTTrainer(BaseTrainer):
                                                                  num_warmup_steps=args.lr_warmup * updates_total,
                                                                  num_training_steps=updates_total)
         # create loss function
-        rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
+        rel_criterion = torch.nn.CrossEntropyLoss(reduction='none')
         entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')

spert/loss.py

 class SpERTLoss(Loss):
 
         if rel_count.item() != 0:
             rel_logits = rel_logits.view(-1, rel_logits.shape[-1])
-            rel_types = rel_types.view(-1, rel_types.shape[-1])
+            rel_types = rel_types.view(-1)

             rel_loss = self._rel_criterion(rel_logits, rel_types)
-            rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1]
             rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count
 
             # joint loss
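With `CrossEntropyLoss(reduction='none')`, the criterion already returns one value per pair (input `[N, C]`, target `[N]`). That is why the per-class averaging line `rel_loss.sum(-1) / rel_loss.shape[-1]` is no longer needed. The masked average in the remaining lines can be sketched in plain Python. The function names below are hypothetical and only mirror the diff.

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed via log-sum-exp for stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def masked_rel_loss(rel_logits, rel_types, rel_sample_masks):
    """Average per-pair cross-entropy over real relation samples only.

    rel_logits: one list of class logits per pair (column 0 = 'No Relation')
    rel_types: one class index per pair (0 for negative samples)
    rel_sample_masks: 1 for real samples, 0 for batch padding
    """
    rel_count = sum(rel_sample_masks)
    if rel_count == 0:  # the diff skips the relation loss in this case
        return 0.0
    # reduction='none' equivalent: one loss value per pair, no per-class averaging
    losses = [cross_entropy(lg, t) for lg, t in zip(rel_logits, rel_types)]
    return sum(loss * m for loss, m in zip(losses, rel_sample_masks)) / rel_count

# Two real pairs with uninformative logits and one padded pair that must be ignored
loss = masked_rel_loss([[0.0, 0.0], [0.0, 0.0], [5.0, -5.0]], [1, 0, 1], [1, 1, 0])
print(loss)  # log(2), about 0.693
```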

spert/sampling.py

 def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span
         rel_sample_masks = torch.zeros([1], dtype=torch.bool)
 
     # relation types to one-hot encoding
-    rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
-    rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
-    rel_types_onehot = rel_types_onehot[:, 1:]  # all zeros for 'none' relation
 
     return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
                 entity_sizes=entity_sizes, entity_types=entity_types,
-                rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
+                rels=rels, rel_masks=rel_masks, rel_types=rel_types,
                 entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)
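This sampling change swaps the target format. The sketch below uses a hypothetical plain-Python helper that mirrors the removed `scatter_` and `[:, 1:]` lines; it is not SpERT code. It shows the two formats for the same three pairs, using the relation indices from the log (No Relation=0, Effect=1, Int=2, Mechanism=3, Advise=4).

```python
def to_sigmoid_targets(rel_types, rel_type_count):
    """Original style: one-hot over all types, then drop column 0 ('No Relation').

    A negative pair (type 0) becomes an all-zero row.
    """
    rows = []
    for t in rel_types:
        onehot = [0.0] * rel_type_count
        onehot[t] = 1.0
        rows.append(onehot[1:])  # mirrors rel_types_onehot[:, 1:]
    return rows

rel_types = [1, 0, 3]  # Effect, negative sample, Mechanism
print(to_sigmoid_targets(rel_types, 5))  # targets for BCEWithLogitsLoss, shape [3, 4]
print(rel_types)                         # targets for CrossEntropyLoss: the indices themselves
```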

spert/models.py

 class SpERT(BertPreTrainedModel):
             chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
                                                         relations, rel_masks, h_large, i)
             # apply sigmoid
-            chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
-            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
+            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits
 
-        rel_clf = rel_clf * rel_sample_masks  # mask
 
         # apply softmax
         entity_clf = torch.softmax(entity_clf, dim=2)
+        rel_clf = torch.softmax(rel_clf, dim=2)
+        rel_clf *= rel_sample_masks
 
         return entity_clf, rel_clf, relations

spert/predictions.py

 def convert_predictions(batch_entity_clf: torch.tensor, batch_rel_clf: torch.ten
     batch_entity_types *= batch['entity_sample_masks'].long()
 
     # apply threshold to relations
-    batch_rel_clf[batch_rel_clf < rel_filter_threshold] = 0
 
     batch_pred_entities = []
     batch_pred_relations = []

spert/predictions.py

 def _convert_pred_relations(rel_clf: torch.tensor, rels: torch.tensor,
                             entity_types: torch.tensor, entity_spans: torch.tensor, input_reader: BaseInputReader):
-    rel_class_count = rel_clf.shape[1]
-    rel_clf = rel_clf.view(-1)
 
     # get predicted relation labels and corresponding entity pairs
-    rel_nonzero = rel_clf.nonzero().view(-1)
-    pred_rel_scores = rel_clf[rel_nonzero]
-
-    pred_rel_types = (rel_nonzero % rel_class_count) + 1  # model does not predict None class (+1)
-    valid_rel_indices = rel_nonzero // rel_class_count
+    valid_rel_indices = torch.nonzero(torch.sum(rel_clf, dim=-1)).view(-1)
+    valid_rel_indices = valid_rel_indices.view(-1)
+    
+    pred_rel_types = rel_clf[valid_rel_indices]
+    if pred_rel_types.shape[0] != 0:
+        pred_rel_types = pred_rel_types.argmax(dim=-1)
+        valid_rel_indices = torch.nonzero(pred_rel_types).view(-1)
+        
+        pred_rel_types = pred_rel_types[valid_rel_indices]
+
+    pred_rel_scores = rel_clf[valid_rel_indices]
+    if pred_rel_scores.shape[0] != 0:
+        pred_rel_scores = pred_rel_scores.max(dim=-1)[0]

     valid_rels = rels[valid_rel_indices]
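One detail worth checking in the new `_convert_pred_relations`: the second `torch.nonzero(pred_rel_types)` returns positions inside the already-filtered subset, not inside `rels`. As a result, `rels[valid_rel_indices]` and `rel_clf[valid_rel_indices]` stay aligned only as long as the first filter drops nothing but trailing rows, such as padding. The plain-Python sketch below performs the same decoding while keeping indices relative to the full candidate list. `decode_relations` is a hypothetical helper, and it re-adds the removed `rel_filter_threshold` as an optional parameter.

```python
def decode_relations(rel_clf, rels, threshold=0.0):
    """Turn per-pair softmax scores into (pair, type, score) predictions.

    rel_clf: one score list per candidate pair; column 0 = 'No Relation';
             masked (padding) pairs are all-zero rows.
    rels:    the (head, tail) entity indices of each candidate pair, same order.
    Iterating over both lists together keeps pairs and scores aligned.
    """
    preds = []
    for pair, scores in zip(rels, rel_clf):
        score = max(scores)
        rel_type = scores.index(score)
        # Skip 'No Relation' winners and masked rows (a score of 0 never exceeds threshold >= 0)
        if rel_type != 0 and score > threshold:
            preds.append((pair, rel_type, score))
    return preds

rel_clf = [[0.1, 0.7, 0.1, 0.05, 0.05],
           [0.8, 0.1, 0.05, 0.03, 0.02],
           [0.0, 0.0, 0.0, 0.0, 0.0],      # padded pair
           [0.2, 0.1, 0.6, 0.05, 0.05]]
rels = [(0, 1), (1, 0), (0, 0), (2, 3)]
print(decode_relations(rel_clf, rels))
```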

Unrelated to the previous topic, but I thought I'd add it here since the same dataset is involved.
While experimenting with the original SpERT, I replaced BERT with SciBERT. With 1 epoch of training I had no issues whatsoever, but when I increased it to 5 epochs, the routine that stores the predictions started to pick up relations that (if I interpreted the code correctly) should have been filtered out earlier.
Here is the log:


Config:
{'label': 'scibert_ddi', 'model_type': 'spert', 'model_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'tokenizer_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times

Iteration 0

2021-05-28 10:54:28,162 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-28 10:54:28,162 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6045/6045 [00:12<00:00, 466.41it/s]
Parse dataset 'valid': 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 931/931 [00:02<00:00, 359.59it/s]
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Relation type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entity type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entities:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Entity=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug name=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Group=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Brand=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relations:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Relation=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Effect=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Int=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Mechanism=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Advise=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: train
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 6045
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relation count: 3378
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Entity count: 12549
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: valid
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 931
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Relation count: 642
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Entity count: 2216
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates total: 15110

[...]

Evaluation

--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly

           type    precision       recall     f1-score      support
          brand         0.00         0.00         0.00        169.0
           drug         0.00         0.00         0.00       1396.0
         drug_n         0.00         0.00         0.00        101.0
          group         0.00         0.00         0.00        538.0

          micro         0.00         0.00         0.00       2204.0
          macro         0.00         0.00         0.00       2204.0

--- Relations ---

Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
      mechanism         0.00         0.00         0.00        253.0
         effect         0.00         0.00         0.00        250.0
            int         0.00         0.00         0.00          8.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
      mechanism         0.00         0.00         0.00        253.0
         effect         0.00         0.00         0.00        250.0
            int         0.00         0.00         0.00          8.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/deeplearning/Salvalai/spert/spert.py", line 16, in __train
    trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path,
  File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 97, in train
    self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)
  File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 253, in _eval
    evaluator.store_predictions()
  File "/home/deeplearning/Salvalai/spert/spert/evaluator.py", line 87, in store_predictions
    prediction.store_predictions(self._dataset.documents, self._pred_entities,
  File "/home/deeplearning/Salvalai/spert/spert/prediction.py", line 196, in store_predictions
    head_idx = converted_entities.index(converted_head)
ValueError: {'type': 'None', 'start': 0, 'end': 1} is not in list
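The `ValueError` means a predicted relation points to a head entity, `{'type': 'None', 'start': 0, 'end': 1}`, that is not among the stored entities, so `list.index` fails. The following is a defensive sketch only. `link_relations` is a hypothetical helper with an assumed layout in which relation heads and tails are entity dicts. It does not explain why a 'None'-typed head reached this point; it just looks entities up in a dictionary and sets such relations aside instead of crashing.

```python
def link_relations(entities, relations):
    """Map each relation's head/tail entity dicts to indices in `entities`.

    Relations whose head or tail is missing from `entities` (e.g. a span whose
    type came out as 'None') are returned separately instead of raising the
    ValueError that list.index would raise.
    """
    def key(e):
        return (e["type"], e["start"], e["end"])

    position = {key(e): i for i, e in enumerate(entities)}
    linked, dropped = [], []
    for rel in relations:
        head, tail = key(rel["head"]), key(rel["tail"])
        if head in position and tail in position:
            linked.append({"type": rel["type"], "head": position[head], "tail": position[tail]})
        else:
            dropped.append(rel)
    return linked, dropped

entities = [{"type": "Drug", "start": 6, "end": 7}, {"type": "Drug", "start": 7, "end": 8}]
relations = [
    {"type": "Effect", "head": entities[0], "tail": entities[1]},
    {"type": "Effect", "head": {"type": "None", "start": 0, "end": 1}, "tail": entities[1]},
]
linked, dropped = link_relations(entities, relations)
print(linked, len(dropped))
```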

Best regards

Hi,
I just pushed a fix for a corner case (commit e0d9aee) that may be related to your problem. In some cases (especially strings containing only control characters), the tokenizer we use maps tokens to empty subword sequences, which can lead to divisions by zero (and NaN values) further down the line. Could you please check whether the commit fixes your problem? Even if it does, it would still be better to remove any control characters from your dataset beforehand.
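To remove control characters beforehand, one option is a preprocessing pass over the converted dataset. It would drop control-only tokens, such as the trailing `"\r\n"` in the sample from the original post, and shift entity spans accordingly. This is a hypothetical sketch, not part of SpERT. It assumes exclusive `end` indices and that no entity consists solely of removed tokens. Relation `head`/`tail` indices need no change because the entity list keeps its order.

```python
import unicodedata

def is_control_only(token: str) -> bool:
    # True for empty/whitespace-only tokens and tokens made only of Unicode "C*"
    # characters (control, format, ...), e.g. "\r\n" or "\u200b"
    return not token.strip() or all(unicodedata.category(c).startswith("C") for c in token)

def drop_control_tokens(doc: dict) -> dict:
    """Return a copy of a SpERT-style document without control-only tokens.

    Entity spans (exclusive 'end') are shifted to the new token positions.
    Assumes no entity consists solely of removed tokens.
    """
    keep = [i for i, t in enumerate(doc["tokens"]) if not is_control_only(t)]
    new_pos = {old: new for new, old in enumerate(keep)}
    entities = []
    for ent in doc["entities"]:
        kept = [new_pos[i] for i in range(ent["start"], ent["end"]) if i in new_pos]
        entities.append({**ent, "start": kept[0], "end": kept[-1] + 1})
    return {**doc, "tokens": [doc["tokens"][i] for i in keep], "entities": entities}

doc = {"tokens": ["a", "\r\n", "b", "c", "\r\n"],
       "entities": [{"type": "drug", "start": 2, "end": 4}],
       "relations": []}
print(drop_control_tokens(doc))
```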

If this does not fix your issue, could you please send me the dataset (or a representative part of it) by email (markus.eberts@hs-rm.de)? I can have a look at it then.

Thank you, I'll check as soon as possible and let you know.