gaussalgo/adaptor

`TokenClassification` objective sometimes misaligns tokens and wordpieces


The following code shows that the `TokenClassification` objective sometimes fails to align tokens and wordpieces:

>>> text = 'Ce. Jif. Cebivskym z Záduba a Kat. z Hostouné о odkaz . . Janem z Pernsteina a Boh. Cernfnem o rukojemstvi . Lad. z Boskovic a Boh. Cerninem o rukojemstvi . . Han. Trojem kupcem a Mat, Libikem z Radovesic o réeni . Jindř., Kulem z Véfic a Zikm. Pétipeskym z Kr. Dvoru o koné . Mik. postiihaem z Budějovic a Jindř. Sudlicem z Jivovice o dluh Jiř. z Puchova a Kun. Sertyngrem z Sertynge o koné . Krist. Talkenberkem a Janem Blektou o vložení včna do desk Jif. ze Stranec a Katef. Kozlovou o základ propadeny . . , Arn. z Drasova a Linh, Nekáem z Landeku o jistinu . Mik. Kouhou a Václ. z Dédibab o dluh . 654 2087. 2088. 2089. 2090. 2091. 2092. 2093. 2094. 2095. 2096. 2097. 2098. 2099. 2100. 2101. 2102. 2103. 2104. 2105. 2106. 2107. 2108. 2109. 2110. 2111. 2112. 2113. 2114. 2115. 2116. 2117. 2118. 2119. 2120. 2121. 2122. 2123. 2124. 2125. 2126. 2127. 2128. 2129. 2130. 2131. 2132. 2133. 2134. 2135. 2136. . list, . list. . list. . list. . list, . list. . list. . dub. . dub.'
>>> labels = 'O O O O O O O O O O O O O O O O O O O O O O O O LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O'
>>>
>>> assert len(text.split()) == len(labels.split())
>>>
>>> from adaptor.lang_module import LangModule
>>> from adaptor.objectives.classification import TokenClassification
>>> 
>>> lang_module = LangModule('xlm-roberta-base')
>>> objective = TokenClassification(lang_module, batch_size=1, texts_or_path=[text], labels_or_path=[labels])
>>> list(objective._wordpiece_token_label_alignment([text], [labels]))

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/***/adaptor/objectives/classification.py", line 52, in _wordpiece_token_label_alignment
    next_token = tokens[0]

IndexError: list index out of range

After removing all tokens (and their corresponding labels) that consist solely of non-word characters (`re.fullmatch(r'\W+', token, re.UNICODE)`), the problem disappears, which suggests that the standalone `.` and `,` tokens might be the source of the issue.
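The filtering step described above can be sketched roughly as follows. This is only a workaround for narrowing down the problem, not a fix; the helper name `drop_punctuation_only_tokens` and the toy input are made up for illustration and are not part of adaptor.

```python
import re
from typing import Tuple

# Matches tokens made up solely of non-word characters, e.g. "." or ",".
PUNCT_ONLY = re.compile(r"\W+", re.UNICODE)


def drop_punctuation_only_tokens(text: str, labels: str) -> Tuple[str, str]:
    """Drop whitespace-separated tokens without any word character, plus their labels.

    Hypothetical helper for reproducing the workaround; not an adaptor API.
    """
    tokens = text.split()
    tags = labels.split()
    if len(tokens) != len(tags):
        raise ValueError(f"{len(tokens)} tokens but {len(tags)} labels")

    # Keep only (token, label) pairs whose token has at least one word character.
    kept = [(tok, tag) for tok, tag in zip(tokens, tags)
            if not PUNCT_ONLY.fullmatch(tok)]
    kept_tokens = [tok for tok, _ in kept]
    kept_tags = [tag for _, tag in kept]
    return " ".join(kept_tokens), " ".join(kept_tags)


# Toy input: the two standalone "." tokens are dropped, "dub." is kept.
text, labels = drop_punctuation_only_tokens("Jan z Prahy . . dub.", "O O LOC O O O")
print(text)    # Jan z Prahy dub.
print(labels)  # O O LOC O
```

Note that `\W` treats digits and underscores as word characters, so tokens such as `2087.` survive the filter.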

We can guard against the `IndexError` by setting the last artificial token to `None` instead of `wordpieces[-1]`, which is what the code currently appends:

# next token lookup - avoid out-of-index, and exclude from token labels
tokens.append(wordpieces[-1])
labels.append("O")

This should ensure that we never accidentally consume the last artificial token if we somehow get ahead of the wordpieces during alignment. Of course, we should still investigate why the misalignment happens and fix it, so that we don't silently feed garbage to the model!
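To make the idea concrete, here is a minimal, self-contained sketch of a `None` sentinel in a toy alignment loop. It operates on plain lists of strings and uses a deliberately simplified prefix-matching rule, so it is not adaptor's actual implementation; `align_labels`, `IGNORE`, and the integer labels are assumptions made for the example.

```python
from typing import List, Optional

IGNORE = -100  # hypothetical "ignore" label id, assumed here for illustration


def align_labels(tokens: List[str], labels: List[int],
                 wordpieces: List[str]) -> List[int]:
    """Label the first wordpiece of each token; mark all other wordpieces IGNORE.

    Simplified stand-in for the real alignment, meant only to show the sentinel.
    """
    # The None sentinel keeps the lookup queue[0] valid and can never match
    # a wordpiece, unlike a copy of wordpieces[-1].
    queue: List[Optional[str]] = list(tokens) + [None]
    label_queue = list(labels) + [IGNORE]

    out = []
    for wpiece in wordpieces:
        next_token = queue[0]
        piece = wpiece.strip()
        if next_token is not None and piece and next_token.startswith(piece):
            # This wordpiece starts the next token: emit its label and advance.
            out.append(label_queue.pop(0))
            queue.pop(0)
        else:
            out.append(IGNORE)

    # Fail loudly instead of silently returning misaligned labels.
    if queue != [None]:
        raise ValueError(f"{len(queue) - 1} token(s) were not aligned to any wordpiece")
    return out


print(align_labels(["Praha", "."], [1, 0], ["Pra", "ha", "."]))  # [1, -100, 0]
```

Because the sentinel is never popped, `queue[0]` cannot raise an `IndexError` even when the wordpieces outnumber the tokens, and any tokens left over at the end surface as an explicit error rather than as garbage labels.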

Tasks

  • Fix and unit-test the misalignment.
  • Make the last artificial token in `_wordpiece_token_label_alignment()` unconsumable.

@stefanik12 Sorry for the lack of the M in my MWE.