Transformer-based baseline
snakers4 opened this issue · 9 comments
snakers4 commented
- Use the OpenAI transformer + LM (language modelling) model
- Use 2 heads; write a custom data generator
- Write a custom inference / evaluation loop. Key ideas:
  - Pass the whole input to the transformer
  - Take the logits for all positions and go left to right, one by one, until reaching the first token where the prediction differs from the input token
  - Change this token and repeat until the end of the phrase is reached
- Add / implement the metric from the repo
- Write a basic dataset class
snakers4 commented
@thinline72
Symbol distribution
Everything with fewer than ~4k occurrences is obviously rubbish
char | count |
---|---|
" " (space) | 1975506 |
А | 1711205 |
В | 1689781 |
О | 1622053 |
И | 1588955 |
Р | 1447465 |
Н | 1446088 |
Л | 1249613 |
Е | 1244502 |
М | 1097509 |
Д | 1002928 |
К | 997736 |
Ч | 951521 |
У | 943420 |
С | 929584 |
Т | 879522 |
Б | 762450 |
Х | 605935 |
Г | 565967 |
З | 461391 |
Й | 381082 |
Ш | 374876 |
Ж | 372056 |
Ь | 327174 |
Я | 325978 |
П | 279990 |
Ф | 220198 |
A | 219131 |
Ы | 210798 |
I | 181841 |
O | 180477 |
V | 179476 |
R | 174784 |
N | 167101 |
Ю | 144361 |
E | 135067 |
L | 127865 |
M | 126047 |
U | 124653 |
S | 123546 |
H | 122738 |
D | 106453 |
K | 102485 |
T | 98554 |
Y | 89749 |
B | 88956 |
Э | 83410 |
G | 65340 |
C | 64942 |
Z | 61941 |
Ц | 58751 |
Ё | 47198 |
J | 42915 |
Щ | 31887 |
F | 28175 |
X | 27009 |
P | 26159 |
Ъ | 17875 |
"-" | 13455 |
Q | 12121 |
W | 6394 |
. | 6258 |
' | 4122 |
І | 21 |
` | 6 |
Ї | 3 |
_ | 2 |
а | 1 |
р | 1 |
б | 1 |
н | 1 |
и | 1 |
х | 1 |
д | 1 |
5 | 1 |
snakers4 commented
snakers4 commented
snakers4 commented
snakers4 commented
snakers4 commented
Basic launch script
# Launch the baseline transformer training run on GPU 2.
# The last argument line has no trailing backslash: a dangling `\` would make
# the shell join the next line into this command.
CUDA_VISIBLE_DEVICES=2 python3 pytorch.train_transformer.py \
--tensorboard True --tb_name tr_baseline_adam --csv_log tr_runs.csv \
--optimizer adam \
--batch_size 2000 --num_workers 2 \
--lr 1e-4 --freeze_emb False \
--emb_size 50 \
--add_cn_embeddings False \
--n_head 2 --n_layer 2 \
--embd_pdrop 0.1 --attn_pdrop 0.1 --resid_pdrop 0.1 --clf_pdrop 0.1 \
--act_fn gelu \
--epochs 100
snakers4 commented
Some useful code for EDA
from IPython.display import Markdown
def style_text(text, colour):
    """Wrap *text* in an HTML ``<span>`` with the given background *colour*."""
    return '<span style="background-color: {}">{}</span>'.format(colour, text)


def print_style(marked_text):
    """Render *marked_text* as Markdown in the notebook output."""
    # Import locally: `display` is only a global inside an IPython session,
    # so without this the helper raises NameError anywhere else.
    from IPython.display import Markdown, display
    display(Markdown(marked_text))


def process_text(name, gt_name):
    """Show *name* aligned against *gt_name* with mismatches highlighted.

    Prints both lengths, right-pads the shorter string with spaces, and walks
    the two strings position by position using a one-character look-ahead:

    * name is missing a character -> a red placeholder is inserted into name;
    * name has an extra character -> it is highlighted red and a blue
      placeholder is inserted into gt_name;
    * otherwise (a replacement)    -> the character in name is highlighted red.

    Both highlighted strings are then rendered via ``print_style``.
    Returns nothing.
    """
    print('Name {} / GT Name {}'.format(len(name), len(gt_name)))
    max_len = max(len(name), len(gt_name))
    if len(name) < max_len:
        name += ' ' * (max_len - len(name))
    if len(gt_name) < max_len:
        gt_name += ' ' * (max_len - len(gt_name))
    # Sentinel so the look-ahead at i + 1 never runs past the end.
    name = list(name) + ['☠']
    gt_name = list(gt_name) + ['☠']
    # NOTE: the range is fixed up front; the inserts below lengthen the lists,
    # so only the first max_len aligned positions are ever inspected.
    for i in range(max_len):
        if name[i] == gt_name[i]:
            continue
        if name[i] == gt_name[i + 1]:
            # name is missing a character: insert a placeholder to realign.
            name.insert(i, style_text(' ', '#ff5050'))
        elif name[i + 1] == gt_name[i]:
            # name has an extra character: highlight it, pad gt_name.
            name[i] = style_text(name[i], '#ff5050')
            gt_name.insert(i, style_text(' ', '#3366ff'))
        else:
            # Single-character replacement.
            name[i] = style_text(name[i], '#ff5050')
    print_style(''.join(name))
    print_style(''.join(gt_name))
snakers4 commented
Updated error visualization
Now covers almost all error types
from IPython.display import Markdown
def style_text(text, colour):
    """Wrap *text* in an HTML ``<span>`` with the given background *colour*."""
    return '<span style="background-color: {}">{}</span>'.format(colour, text)


def print_style(marked_text):
    """Render *marked_text* as Markdown in the notebook output."""
    # Import locally: `display` is only a global inside an IPython session,
    # so without this the helper raises NameError anywhere else.
    from IPython.display import Markdown, display
    display(Markdown(marked_text))


def process_text(name, gt_name, is_print=False):
    """Align *name* against *gt_name*, count the errors and optionally show them.

    Both strings are right-padded with spaces to a common length and compared
    position by position with a one-character look-ahead. Each mismatch is
    classified as one error:

    * ``'swap'``    -- name[i] and name[i + 1] are transposed;
    * ``'del'``     -- name is missing a character (a red placeholder is
      inserted into name so the rest realigns);
    * ``'extra'``   -- name has an extra character (highlighted red, and a
      blue placeholder is inserted into gt_name);
    * ``'replace'`` -- any other single-character substitution.

    Args:
        name: predicted string.
        gt_name: ground-truth string. A float (e.g. a NaN read from a
            dataframe) means "no ground truth".
        is_print: if true, print the lengths and error count and render both
            highlighted strings via ``print_style``.

    Returns:
        ``(error_counter, error_list)``. Returns ``(0, 0)`` when gt_name is a
        float.
    """
    # isinstance also catches float subclasses such as numpy.float64 NaN,
    # which `type(x) == float` missed (and then crashed on len()).
    if isinstance(gt_name, float):
        return 0, 0
    if is_print:
        print('Name {} / GT Name {}'.format(len(name), len(gt_name)))
    max_len = max(len(name), len(gt_name))
    # Pad to a common length (without mutating the caller's object) and add a
    # sentinel so the look-ahead at i + 1 never runs past the end.
    name = list(name) + [' '] * (max_len - len(name)) + ['☠']
    gt_name = list(gt_name) + [' '] * (max_len - len(gt_name)) + ['☠']
    error_counter = 0
    error_list = []
    i = 0
    while i < max_len:
        if name[i] != gt_name[i]:
            if name[i] == gt_name[i + 1] and name[i + 1] == gt_name[i]:
                # Adjacent transposition: highlight both, skip the partner.
                name[i] = style_text(name[i], '#ff5050')
                name[i + 1] = style_text(name[i + 1], '#ff5050')
                error_list.append('swap')
                i += 1
            elif name[i] == gt_name[i + 1]:
                # name is missing a character.
                name.insert(i, style_text('☐', '#ff5050'))
                error_list.append('del')
            elif name[i + 1] == gt_name[i]:
                # name has an extra character.
                name[i] = style_text(name[i], '#ff5050')
                gt_name.insert(i, style_text('☐', '#3366ff'))
                error_list.append('extra')
            else:
                # Plain substitution.
                name[i] = style_text(name[i], '#ff5050')
                error_list.append('replace')
            # Every branch is exactly one error, so the count always matches
            # error_list (swaps were previously not counted).
            error_counter += 1
        i += 1
    if is_print:
        print('Errors {}'.format(error_counter))
        print_style(''.join(name))
        print_style(''.join(gt_name))
    return error_counter, error_list
snakers4 commented
def get_errors(name, gt_name):
    """Align *name* with *gt_name* and list the corrections *name* needs.

    Both strings are right-padded with spaces to a common length and compared
    position by position with a one-character look-ahead. Each mismatch at
    aligned position ``i`` is recorded as one correction:

    * ``'_swap_'``   -- name[i] and name[i + 1] are transposed;
    * ``'_insert_'`` -- name is missing a character (a placeholder is inserted
      into name so the rest of the strings realign);
    * ``'_del_'``    -- name[i] is an extra character (a placeholder is
      inserted into gt_name);
    * otherwise the ground-truth character ``gt_name[i]`` (a replacement).

    Returns:
        ``(error_positions, error_corrections)``: parallel lists. Positions
        index the aligned sequences, which grow as placeholders are inserted,
        so they are not necessarily indices into the raw *name*.
    """
    max_len = max(len(name), len(gt_name))
    # Pad to a common length and append a sentinel: this guarantees the
    # look-ahead at i + 1 is always in range, so no try/except is needed.
    name = list(name) + [' '] * (max_len - len(name)) + ['☠']
    gt_name = list(gt_name) + [' '] * (max_len - len(gt_name)) + ['☠']
    error_positions = []
    error_corrections = []
    i = 0
    while i < max_len:
        if name[i] != gt_name[i]:
            error_positions.append(i)
            if name[i] == gt_name[i + 1] and name[i + 1] == gt_name[i]:
                error_corrections.append('_swap_')
                i += 1  # the transposed partner is already accounted for
            elif name[i] == gt_name[i + 1]:
                error_corrections.append('_insert_')
                name.insert(i, ' ')
            elif name[i + 1] == gt_name[i]:
                error_corrections.append('_del_')
                gt_name.insert(i, ' ')
            else:
                error_corrections.append(gt_name[i])
        i += 1
    return error_positions, error_corrections