snakers4/cft-contest-2018

Transformer based baseline

snakers4 opened this issue · 9 comments

  • Use the OpenAI transformer + language modelling (LM) model
  • Use 2 heads, write custom data generator
  • Write custom inference / evaluation loop - key ideas:
    • Pass the whole input to transformer
    • Take the logits for all positions and scan them one by one, left to right, until reaching the first token where the prediction differs from the input token
    • Change this token, repeat, until end of phrase is reached
  • Add / implement metric from the repo
  • Write a basic dataset class

@thinline72
Symbol distribution
Everything with a count below 4k is obviously rubbish

char count
  1975506
А 1711205
В 1689781
О 1622053
И 1588955
Р 1447465
Н 1446088
Л 1249613
Е 1244502
М 1097509
Д 1002928
К 997736
Ч 951521
У 943420
С 929584
Т 879522
Б 762450
Х 605935
Г 565967
З 461391
Й 381082
Ш 374876
Ж 372056
Ь 327174
Я 325978
П 279990
Ф 220198
A 219131
Ы 210798
I 181841
O 180477
V 179476
R 174784
N 167101
Ю 144361
E 135067
L 127865
M 126047
U 124653
S 123546
H 122738
D 106453
K 102485
T 98554
Y 89749
B 88956
Э 83410
G 65340
C 64942
Z 61941
Ц 58751
Ё 47198
J 42915
Щ 31887
F 28175
X 27009
P 26159
Ъ 17875
"-" 13455
Q 12121
W 6394
. 6258
' 4122
І 21
` 6
Ї 3
_ 2
а 1
р 1
б 1
н 1
и 1
х 1
д 1
5 1

The test set also contains these symbols
This is ridiculous
default

Discrepancy analysis

Mostly superficial
default

default

Max seq len is ~100
Seems like a worthy task for a transformer xD

default

top countries
ideally they should be matched manually

default

Basic launch script

# Baseline training run of pytorch.train_transformer.py.
# CUDA_VISIBLE_DEVICES=2 exposes only GPU index 2 to the process.
# NOTE(review): the flag meanings are defined by the training script, which is
# not shown here; by their names this is a 2-head / 2-layer model trained with
# Adam at lr 1e-4 for 100 epochs - confirm against the script's argparse setup.
CUDA_VISIBLE_DEVICES=2 python3 pytorch.train_transformer.py \
	--tensorboard True --tb_name tr_baseline_adam --csv_log tr_runs.csv \
	--optimizer adam \
	--batch_size 2000 --num_workers 2 \
	--lr 1e-4 --freeze_emb False \
	--emb_size 50 \
	--add_cn_embeddings False \
	--n_head 2 --n_layer 2 \
	--embd_pdrop 0.1 --attn_pdrop 0.1 --resid_pdrop 0.1 --clf_pdrop 0.1 \
	--act_fn gelu \
	--epochs 100 \

Some useful code for EDA

from IPython.display import Markdown

def style_text(text, colour):
    """Wrap ``text`` in an HTML ``<span>`` with ``colour`` as its background.

    Args:
        text: Content placed inside the span (inserted as-is, not escaped).
        colour: CSS colour value used for ``background-color``.

    Returns:
        The HTML snippet as a string.
    """
    span_template = '<span style="background-color: {}">{}</span>'
    return span_template.format(colour, text)

def print_style(marked_text):
    """Wrap ``marked_text`` in an IPython ``Markdown`` object and display it.

    NOTE(review): ``display`` is not imported in this snippet; presumably it is
    the name IPython/Jupyter provides inside a notebook session - confirm
    before running this outside a notebook.
    """
    rendered = Markdown(marked_text)
    display(rendered)
    
def process_text(name,gt_name):
    """Print ``name`` and ``gt_name`` with their differences highlighted.

    Both strings are right-padded with spaces to a common length and walked
    position by position. Mismatching characters are wrapped with
    ``style_text`` and the two marked-up strings are rendered through
    ``print_style``.

    Args:
        name: The observed (possibly misspelled) string.
        gt_name: The ground-truth string.

    Returns:
        None. The function only prints / displays.
    """
    print('Name {} / GT Name {}'.format(len(name),len(gt_name)))
    width = max(len(name), len(gt_name))

    # A trailing sentinel keeps the ``pos + 1`` look-ahead inside the list.
    observed = list(name.ljust(width)) + ['☠']
    expected = list(gt_name.ljust(width)) + ['☠']

    red = '#ff5050'
    blue = '#3366ff'

    for pos in range(width):
        if observed[pos] == expected[pos]:
            continue

        if observed[pos] == expected[pos + 1]:
            # A character is missing from the observed string: show a red gap.
            observed.insert(pos, style_text('  ', red))
        elif observed[pos + 1] == expected[pos]:
            # The observed string has an extra character: mark it red and put
            # a blue placeholder in the ground truth to keep both aligned.
            observed[pos] = style_text(observed[pos], red)
            expected.insert(pos, style_text(' ', blue))
        else:
            # Any other mismatch: just mark the observed character.
            observed[pos] = style_text(observed[pos], red)

    print_style(''.join(observed))
    print_style(''.join(expected))

Updated error viz
Covers almost all errors now

from IPython.display import Markdown

def style_text(text, colour):
    """Return ``text`` inside an HTML span whose background is ``colour``.

    ``text`` is inserted verbatim (no HTML escaping); ``colour`` is placed
    into the ``background-color`` style property.
    """
    html = '<span style="background-color: {}">{}</span>'.format(colour, text)
    return html

def print_style(marked_text):
    """Display ``marked_text`` wrapped in an IPython ``Markdown`` object.

    NOTE(review): ``display`` is not imported in this snippet; presumably it is
    supplied by the notebook environment - confirm before reuse elsewhere.
    """
    markdown_obj = Markdown(marked_text)
    display(markdown_obj)
    
def process_text(name,gt_name,is_print=False):
    """Count and classify the differences between ``name`` and ``gt_name``.

    Both strings are right-padded with spaces to a common length and compared
    position by position with a one-character look-ahead. Each mismatch is
    classified and the offending characters are wrapped with ``style_text``
    so the result can be rendered through ``print_style``.

    Args:
        name: The observed (possibly misspelled) string.
        gt_name: The ground-truth string, or a float when there is none
            (presumably a NaN coming from a data frame - confirm with caller).
        is_print: When true, print the lengths and error count and display
            both marked-up strings.

    Returns:
        ``(error_counter, error_list)`` where ``error_list`` holds one label
        per detected error (``'swap'``, ``'del'`` or ``'extra'``).
        When ``gt_name`` is a float the function returns ``(0, 0)``; this
        early-return shape is kept for backward compatibility.
    """
    # isinstance (not an exact type comparison) so float subclasses such as
    # numpy.float64 NaN values are also treated as "no ground truth".
    if isinstance(gt_name, float):
        return 0,0
    if is_print:
        print('Name {} / GT Name {}'.format(len(name),len(gt_name)))

    max_len = max(len(name),len(gt_name))

    if len(name)<max_len:
        name += ' '*(max_len-len(name))

    if len(gt_name)<max_len:
        gt_name += ' '*(max_len-len(gt_name))

    # The sentinel keeps the ``i+1`` look-ahead inside the list bounds.
    name = list(name)+['☠']
    gt_name = list(gt_name)+['☠']
    error_counter = 0
    error_list = []

    i = 0
    while i < max_len:
        if name[i]!=gt_name[i]:

            # two neighbouring characters are transposed
            if name[i]==gt_name[i+1] and name[i+1]==gt_name[i]:
                name[i] = style_text(name[i],'#ff5050')
                name[i+1] = style_text(name[i+1],'#ff5050')
                # A swap is an error too: count it so that error_counter
                # stays in step with error_list.
                error_counter+=1
                error_list.append('swap')
                # skip the second character of the swapped pair
                i += 1
            # a character is missing from ``name``
            elif name[i]==gt_name[i+1]:
                name.insert(i, style_text('☐','#ff5050'))
                error_counter+=1
                error_list.append('del')
            # ``name`` contains an extra character
            elif name[i+1]==gt_name[i]:
                name[i] = style_text(name[i],'#ff5050')
                gt_name.insert(i, style_text('☐','#3366ff'))
                error_counter+=1
                error_list.append('extra')
            else:
                # any other mismatch (e.g. a replaced character)
                # NOTE(review): this reuses the 'extra' label, so such errors
                # cannot be told apart from real extra characters in
                # error_list - confirm whether a separate label is wanted.
                name[i] = style_text(name[i],'#ff5050')
                error_counter+=1
                error_list.append('extra')
        i += 1
    if is_print:
        print('Errors {}'.format(error_counter))
        print_style(''.join(name))
        print_style(''.join(gt_name))

    return error_counter,error_list
def get_errors(name,gt_name):
    """Locate the differences between ``name`` and ``gt_name``.

    Both strings are right-padded with spaces to a common length and compared
    position by position with a one-character look-ahead. Whenever a missing
    or extra character is detected, a space is inserted into the working copy
    of the other string so the remaining characters stay aligned; reported
    positions are therefore indices into these working (aligned) lists.

    Args:
        name: The observed (possibly misspelled) string.
        gt_name: The ground-truth string.

    Returns:
        ``(error_positions, error_corrections)`` - two parallel lists. Each
        correction is ``'_swap_'`` (two neighbouring characters transposed),
        ``'_insert_'`` (a character is missing from ``name``), ``'_del_'``
        (``name`` has an extra character) or, for a plain replacement, the
        ground-truth character itself. Two empty lists are returned when
        either argument is a float (presumably a NaN placeholder for a
        missing value - confirm with caller).
    """
    # Same convention as process_text: a float means there is nothing to
    # compare, so report no errors instead of failing on len().
    if isinstance(name, float) or isinstance(gt_name, float):
        return [], []

    max_len = max(len(name),len(gt_name))

    if len(name)<max_len:
        name += ' '*(max_len-len(name))

    if len(gt_name)<max_len:
        gt_name += ' '*(max_len-len(gt_name))

    # The sentinel guarantees that ``i+1`` is always a valid index while
    # ``i < max_len`` (inserts only make the lists longer), so the look-ahead
    # below cannot run out of bounds and needs no exception handling.
    name = list(name)+['☠']
    gt_name = list(gt_name)+['☠']

    error_positions = []
    error_corrections = []

    i = 0
    while i < max_len:
        if name[i]!=gt_name[i]:
            # two neighbouring characters are transposed
            if name[i]==gt_name[i+1] and name[i+1]==gt_name[i]:
                error_positions.append(i)
                error_corrections.append('_swap_')
                # skip the second character of the swapped pair
                i += 1

            # a character is missing from ``name``
            elif name[i]==gt_name[i+1]:
                error_positions.append(i)
                error_corrections.append('_insert_')
                name.insert(i,' ')

            # ``name`` contains an extra character
            elif name[i+1]==gt_name[i]:
                error_positions.append(i)
                error_corrections.append('_del_')
                gt_name.insert(i,' ')

            else:
                # plain replacement: the correction is the expected character
                error_positions.append(i)
                error_corrections.append(gt_name[i])
        i += 1

    return error_positions,error_corrections