An implementation of the Restricted Access Sequence Processing (RASP) language for transformers, from the paper "Thinking Like Transformers". From the paper:
We show how RASP can be used to program solutions to tasks that could conceivably be learned by a Transformer, and how a Transformer can be trained to mimic a RASP solution.
Install with pip:
pip install git+https://github.com/yashbonde/rasp
The language is built on three primitives:
- select (creating selection matrices called selectors): this is similar to computing attention, k x q.T => attn
- aggregate (collapsing a selector and an s-op into a new s-op): this is like the value multiplication, attn x v
- selector_width (creating an s-op from a selector): returns the number of selected positions in each row, e.g. selector_width(select(tokens,tokens,==))("hello")=[1,1,2,2,1]
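To make the three primitives concrete, here is a minimal sketch on plain Python lists. It is not the code in rasp/: the function signatures, the `default` argument, and the choice to average multiple selected values are assumptions made for illustration, following the description in the paper.

```python
# Illustrative re-implementation of the three primitives on plain lists.
# NOT the library's code; names and signatures are assumptions.
import operator


def select(keys, queries, predicate):
    """Build a selector: sel[q][k] is True when predicate(keys[k], queries[q]) holds."""
    return [[predicate(k, q) for k in keys] for q in queries]


def aggregate(selector, values, default=0):
    """Collapse each selector row into one output value.

    One selected value is passed through unchanged; several selected
    values are averaged (this only makes sense for numeric values).
    """
    out = []
    for row in selector:
        picked = [v for chosen, v in zip(row, values) if chosen]
        if not picked:
            out.append(default)
        elif len(picked) == 1:
            out.append(picked[0])
        else:
            out.append(sum(picked) / len(picked))
    return out


def selector_width(selector):
    """Count the selected positions in each row."""
    return [sum(row) for row in selector]


tokens = list("hello")
same_token = select(tokens, tokens, operator.eq)
print(selector_width(same_token))  # [1, 1, 2, 2, 1]
```

The selector plays the role of a hard (0/1) attention pattern, and `aggregate` plays the role of multiplying that pattern with the values.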
Repository layout:
- rasp/: main code library
- test/: tests
- primitives: neural functions and primitives, written by hand, that can be used for training
This lets you build complex flows directly in terms of the architecture. For example, a reverse function:
from rasp import RaspModule

reverse = RaspModule('''
def reverse(tokens):
    opp_index = length - indices - 1;
    flip = select(indices, opp_index, ==);
    return aggregate(flip, tokens);
''')
assert reverse("hey") == "yeh"
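To see what each line of the reverse program computes, the following plain-Python trace may help. It does not use the rasp package; `reverse_trace` is a hypothetical helper written only for this walkthrough.

```python
# Plain-Python trace of the reverse program (does not import rasp).
def reverse_trace(text):
    tokens = list(text)
    n = len(tokens)
    indices = list(range(n))
    # opp_index = length - indices - 1
    opp_index = [n - i - 1 for i in indices]
    # flip = select(indices, opp_index, ==):
    # row q marks the key position k where indices[k] == opp_index[q]
    flip = [[k == q for k in indices] for q in opp_index]
    # aggregate(flip, tokens): every row selects exactly one token
    out = [tokens[row.index(True)] for row in flip]
    return opp_index, flip, "".join(out)


opp_index, flip, result = reverse_trace("hey")
print(opp_index)  # [2, 1, 0]
print(result)     # yeh
```

Each row of `flip` has a single True entry, so the aggregate step simply copies the token at the mirrored position.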
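The reverse module would correspond to a neural network along these lines:
from torch import nn

class Flip(nn.Module):
    # flip = select (indices ,opp_index ,==) ;
    def __init__(self):
        super().__init__()
        self.n_head = 1

    def forward(self):
        # placeholder: the forward pass is not implemented yet
        pass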
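As a rough illustration of how a single attention head could realise flip, here is a dependency-free sketch. It is not what this repository generates: the one-hot positional encodings, the `sharpness` scaling factor, and the names `one_hot`, `softmax`, and `flip_head` are all assumptions chosen to keep the example small.

```python
# Hypothetical single-head attention that reverses a sequence.
# Assumes one-hot position encodings; not the repository's generated network.
import math


def one_hot(i, n):
    return [1.0 if j == i else 0.0 for j in range(n)]


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def flip_head(tokens, sharpness=50.0):
    n = len(tokens)
    vocab = sorted(set(tokens))
    keys = [one_hot(i, n) for i in range(n)]             # encodes indices
    queries = [one_hot(n - i - 1, n) for i in range(n)]  # encodes opp_index
    values = [one_hot(vocab.index(t), len(vocab)) for t in tokens]
    out = []
    for q in queries:
        # scaled dot-product scores; a large scale makes softmax nearly hard
        scores = [sharpness * sum(a * b for a, b in zip(q, k)) for k in keys]
        attn = softmax(scores)
        # attention-weighted mix of the one-hot token values
        mixed = [sum(w * v[d] for w, v in zip(attn, values))
                 for d in range(len(vocab))]
        out.append(vocab[mixed.index(max(mixed))])
    return "".join(out)


print(flip_head("hey"))  # yeh
```

The query at position i matches only the key at position length - i - 1, so the softmax puts almost all of its weight on the mirrored token.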
All the tests are in test/. Run them with pytest -v. They cover the following tasks:
- Reverse, e.g.: reverse("abc")="cba"
- Histograms, with a unique beginning-of-sequence (BOS) token $ (e.g., hist_bos("$aba")=[$,2,1,2]) and without it (e.g., hist_nobos("aba")=[2,1,2])
- Double-Histograms, with BOS: for each token, the number of unique tokens with the same histogram value as itself. E.g.: hist2("$abbc")=[$,2,1,1,2]
- Sort, with BOS: ordering the input tokens lexicographically. E.g.: sort("$cba")="$abc"
- Most-Freq, with BOS: returning the unique input tokens in order of decreasing frequency, with original position as a tie-breaker and the BOS token for padding. E.g.: most_freq("$abbccddd")="$dbca$$$$"
- Dyck-i PTF, for i = 1, 2: the task of returning, at each output position, whether the input prefix up to and including that position is a legal Dyck-i sequence (T), and if not, whether it can (P) or cannot (F) be continued into a legal Dyck-i sequence. E.g.: Dyck1_ptf("()())")="PTPTF"
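When writing or checking tests for these tasks, plain-Python reference functions can serve as oracles for the expected outputs. The sketch below is not part of this repository and does not use RASP. It assumes the BOS token is "$" and always the first character, that the Dyck-1 input contains only parentheses, and the name `sort_bos` is a hypothetical stand-in for the sort task.

```python
# Plain-Python reference oracles for some of the tasks listed above.
# Assumptions: BOS is "$" at position 0; Dyck-1 input holds only "(" and ")".
from collections import Counter

BOS = "$"


def hist_bos(text):
    body = text[1:]
    counts = Counter(body)
    return [BOS] + [counts[t] for t in body]


def hist2(text):
    body = text[1:]
    counts = Counter(body)
    # how many distinct tokens share each histogram value
    per_value = Counter(counts.values())
    return [BOS] + [per_value[counts[t]] for t in body]


def sort_bos(text):
    return BOS + "".join(sorted(text[1:]))


def most_freq(text):
    body = text[1:]
    counts = Counter(body)
    first_seen = {}
    for i, t in enumerate(body):
        first_seen.setdefault(t, i)
    # decreasing frequency, earliest position breaks ties
    ranked = sorted(counts, key=lambda t: (-counts[t], first_seen[t]))
    return BOS + "".join(ranked) + BOS * (len(body) - len(ranked))


def dyck1_ptf(text):
    out, depth, dead = [], 0, False
    for ch in text:
        depth += 1 if ch == "(" else -1
        dead = dead or depth < 0  # an unmatched ")" can never be repaired
        out.append("F" if dead else "T" if depth == 0 else "P")
    return "".join(out)


print(hist_bos("$aba"))        # ['$', 2, 1, 2]
print(most_freq("$abbccddd"))  # $dbca$$$$
print(dyck1_ptf("()())"))      # PTPTF
```

Comparing a RASP program's output against such an oracle on random inputs is one way to extend the fixed examples above.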