salesforce/progen

Many repeats in Progen2 model predictions

amirshane opened this issue · 4 comments

I'm following the setup instructions and sampling from a few of the different models but the outputs are very repetitive.

For example:

python3 sample.py --model progen2-small --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1"
outputs
1SPPPPPPGP2

python3 sample.py --model progen2-oas --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1"
outputs
1MMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMM

python3 sample.py --model progen2-oas --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1EVQ"
outputs
1EVQMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMM

python3 sample.py --model progen2-large --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1"
outputs
1SEEFSFEEEWFFWMFALMEFFWWAFFEFWFEFEMWEEFFFEWFWFFFWFWFWEMLMFFFWWFEEEWEEEFFFEFWFWFWEEFMSFMWFEWWFWWFSEWFAFFEFEWWWWSSAFFFFSFFFFFFFWWFWFWFWFFFAFEFFWFEFWFWFEMWFMWFFFFFFFEWFWFFFFWFFEFWFFFWWEWEFFFFFWFFEEAFFWFFEFWFFESFSWEFEFFFWMEMFEWFFEFFEFFWFFWWFAFWWFFMEWFFFFFWFFFFFMWFMWFEWFFFFFFEFFWFFFFFFWMWFFWWMFSFFWFFAFFFFEWEFEFFAWFFEWFAAEAFFEFFFFFEFEFFFMWFWWWEAWMWEFSFFWFFFWWAFWFFWWESAFFSFFFFFWFFFWFFSFSWEEFFWAFFFWFAWFAFAWWMWEFFEFSFFWFFEEEFWFWFFFFMFFFEFWWFFFAEFMWFFFWEFWWEWFSFFWFWWFFWEFFFFWWAEFWFFAWEFALWWFFWWFFFEWWFFFAWFWWFFWEFFWFFFEFWSWFFWWAFFFWEFWSLFAMWFSEMSFAWFEFFWMWWFEFFFEEFFFFFFWWWEFMFFFFWFSLEEFFWEEFFFFFEFMFFWWFWWFWSFWWEEEWWFFEEWFWFFFWFWWFWFWSWEESWWAFFFWFESWWWWSAWSWEEFFWWFAWFFFMFFFWFFFFFWFFEWEEWSWFSWAWFWWWFWFWFWFWFWFFAWWWAEMWWFEWFWMAWWAWWFFEFEEFWFWWFEWWFFFFFWWWWEFALFFSEFAEWAFWEMWFFEFWSFMEEFFFAFFAEEMAFWWFEWSWWFFFFSFWMSFFWFWFFEFFWFWWSFWFMEFEFWEEFWWFMWFWFWFFWAFWWWFFWMWFWFFSWFWALFWSEFSEFFWFFFFFFFEFFFMFFFEFFFFWFWWFEWAFFFSFAFFWWWFEFFWFFWFFWWWFFEWWAMEFFFEWWAWFEFSFWSFAFAFEFEFFEWFWFWFEFWSFFFFFWFFFFFFWFWMEFMFFFEFFWEFSWFFFEFFFFAWFWFFEWMFFE

a-mad commented

we will investigate to make sure there isn't a bug in the code. but it's not surprising that the model can suffer from mode collapse and/or repetition issues. interestingly, the pretraining databases (e.g. uniref) are riddled with spurious artifacts as well FYI.

four simple hacks to solve for this:
(1) sample more. create a large library and simply filter out the poor generations
(2) better decoding strategies. implement a repetition penalty similar to the original ProGen or CTRL paper. huggingface should have support for this
(3) provide initial context to the model. even a few amino acids on either side of the sequence can greatly help
(4) tune the model. prompt-tuning a small number of extra parameters or finetuning to a curated dataset of proteins

@amirshane We have pushed a few corrections and sanity checks. The implementation should now be identical to the one used for the results reported in the paper.

Could you try once more? Does this help?

@amirshane had this issue too, but just created a loop and sampled with many different random seeds

magnesium-transport

progen2 -> alphafold2

thisproteindoesnotexist 😆

Thank you! @enijkamp @a-mad @lucidrains

I just reran the exact same commands and here are the outputs for reference:

python3 sample.py --model progen2-small --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1"
outputs
1MTEQSLLERSFRKGLALLGLCCLASALGLAQAEPAEELAPEGFTLESVSVQGRTRYLRDIAIEEETVPLGALELDQRGRLEQAIARLPGVAINRTSGAPRIVTIQGRGLQSASTLYDGAPVTNPSNNGGAAVRFDAPISDVIEGTQGVVVDAATVGVTGFVQTTVAGSATGGRQGLVDAGYGSFDAESTLAYGANGWRAGVTGGYSHRSRFEQGGLYGSVYRAGVQSELSGSAGAGAAPLGPLEPGTVTREFLPGTAGVSPGTVPGPTSYGFVDQLTLRRDDRLVQDGVSFTWSGAYFDTERPLTPAAGLDASLAAGASGLFERAGTQEATTGLGASAEGRFAERSFDLTATVRATHEDGANSASSSGFVFAPNTGGGGSRSYRETAQGFLEARADTADAWQLDAGLEGSRFRADLAPREFATATRATVGALVGVPAAEGGRDIADPTLARRASLRDRSTRLSGTLRARWDADDWRATAGLTRTRVTDSAPLDARSLTGDGYRITDLSGFGSGRPSDGLTGAVDASLGRALAAGDGGALSASLGLRHSDRRLEAADGDAGVRGTLPGGGTRDIDRTSLAGDARLGLSRRTALTGTVGLSGQDSDLRAGTSVRVRDGATVGASVGLSGGGSGGDSGSASGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSGSG

python3 sample.py --model progen2-oas --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1"
outputs
1SETLSLTCTVSGGSISSYYWSWIRQPPGKGLEWIGYIYYSGSTNYNPSLKSRVTISVDTSKNQFSLKLSSVTAADTAVYYCARSLYDYVWGSYRYSDAFDIWGQGTMVTVSS2

python3 sample.py --model progen2-oas --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1EVQ"
outputs
1EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKSGAPRRLVGATRGGDYWGQGTLVTVSS2

python3 sample.py --model progen2-large --t 0.8 --p 0.9 --max-length 1024 --num-samples 1 --context "1"
outputs
1MKENSFLERSFRKGLALLLTFTLALTLFSAGASPAYAADPEGGTLESVSEQTETTDSGDTAPEEETEEETAEEADVTEEEEEESAEDTEEDAEDTEEADATEEEEEDTAEDAEDAEDSEDTEEESEDTEDAEDTEDAEDTEDAEDTEDAEDTEDTEDTEDAEDTEDAEDTEDTEDTEDAEDTEDAEDTEDAEDTEDAEDTEDTEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTEDAEDTE

Thanks again for the changes. Closing the issue now.