[Critical] Very high loss rate at first few tokens (classifier free guidance not working)
MarcusLoppe opened this issue · 66 comments
@lucidrains
This is an issue I've had for a while: the cross-attention is very weak at the start of the sequence.
When the transformer starts with no tokens, it has to rely on the cross-attention, but unfortunately the cross-attention doesn't work for the first token(s).
Proof
To prove this, I trained on a dataset of 500 models that have unique text embeddings and no augmentations, and I only took the first 6 tokens of each mesh and trained on those.
After training for 8 hours, it's still stuck at 1.03 loss.
Without fixing this issue, autoregressive generation without a prompt of tokens will never work.
This problem has been ongoing for a while, but I thought it was a training issue and that using a model trained on the first few tokens would resolve it. However, that isn't the case.
Real-life example
To highlight the issue, I trained a model on the 13k dataset, then removed all the augmentation copies and the models with duplicate labels.
If I provide it with the first 2 tokens as a prompt, it will autocomplete without any problems or visual issues; however, if I provide it with 1 or 0 tokens, it fails completely.
Checked the logits
I investigated this further and checked the logits when it generated the first token: the correct token was only the 9th most probable.
I tried implementing a beam search with a beam width of 5, but since the first token has such a low probability, it would require a lot of beams. That would probably work, but it's a brute-force solution and not a very good one.
It may work to do a beam search of 20 and then kill off the candidates that seem to have a low probability/entropy, but this seems like a band-aid that might not hold up when scaling up meshgpt.
Why is this a problem?
The first tokens are very important for the generation since there's a domino effect: if it gets an incorrect token at the start, the generation will fail, since it relies too heavily on the preceding sequence to be able to self-correct.
It's like if the sentence is "Dog can be a happy animal" and it predicts "Human" as the first token: the sentence is already off course, and recovering to something like "Human got a dog which can be a happy animal" is extremely unlikely.
Possible solution
Since the cross-attention is used only on the "big" decoder, can it also be implemented for the fine decoder?
Attempts to fix:
- I've tried removing the fine decoder and fine gateloop
- I also tried increasing cross_attn_num_mem_kv but found no significant changes.
- I replaced the TextEmbeddingReturner with AttentionTextConditioner but still saw no changes.
- I tried using different text encoders such as BGE and CLIP.
This has been a problem for a long time and I've mentioned it as a note in other issue threads, so I'm creating an issue for it, since it really prevents me from releasing fine-tuned models.
I have a model ready to go that can predict 13k models, but since the weak first tokens make autoregressive generation impossible, I haven't released it yet.
This sounds critical indeed. Hopefully it's an easy fix.
I think I've resolved this issue by tokenizing the text, inserting it at the start of the codes, and adding a special token to indicate the start of the mesh tokens.
However, the downside is that the transformer needs to use a larger vocab. Any idea if it's possible to reduce the vocab size it's predicting over?
I tested it on a smaller dataset but it seems to be working!
I think this will also guide the transformer much better.
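To make the prefix idea concrete, here is a minimal sketch of one possible sequence layout. It assumes mesh codes occupy ids [0, codebook_size), text-tokenizer ids are shifted up past them, and a single hypothetical special id marks the start of the mesh tokens. The function name and layout are illustrative only, not meshgpt-pytorch's API.

```python
import torch

def build_prefixed_sequence(text_ids: torch.Tensor, mesh_codes: torch.Tensor,
                            codebook_size: int, text_vocab_size: int) -> tuple[torch.Tensor, int]:
    """Hypothetical layout: [shifted text tokens] [start-of-mesh id] [mesh codes]."""
    text_offset = codebook_size                            # text ids live after the mesh codebook
    start_of_mesh_id = codebook_size + text_vocab_size     # one extra special id
    total_vocab = start_of_mesh_id + 1                     # size of the output head
    batch = mesh_codes.shape[0]
    start = torch.full((batch, 1), start_of_mesh_id, dtype=torch.long)
    seq = torch.cat((text_ids + text_offset, start, mesh_codes), dim=1)
    return seq, total_vocab
```

The cost mentioned above shows up in total_vocab: the output head has to cover the mesh codebook plus the whole text vocabulary, even though only mesh codes are predicted after the start-of-mesh id.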
@MarcusLoppe That is fantastic! Have you posted the fix somewhere?
Not yet, my current way is a bit hacky and requires a bit of a rewrite to implement properly.
I'm currently verifying the solution on a bigger dataset and will hammer out all the possible bugs.
@MarcusLoppe hey Marcus, thanks for identifying this issue
have you tried turning off CFG? if you haven't, one thing to try is simply turning off CFG for the first few tokens. i think i've come across papers where they studied at which steps CFG is even effective
also try turning off CFG and do greedy sampling and see what happens. if that doesn't work, there is some big issue
With CFG you mean classifier-free guidance?
Not sure how I would go about that, do you mean setting cond_drop_prob to 0.0?
I've tried that, and as far as I can tell the CFG just returns the embedding without any modifications (with cond_drop_prob set to 0 it never masks the text embedding).
The issue arises when the transformer has an empty sequence and only the text embedding to go on. The text embedding doesn't seem to help very much, so it doesn't know which token to pick, hence the huge loss at the start.
@MarcusLoppe oh, maybe it is already turned off
so CFG is turned on by setting cond_scale > 1 when invoking .generate
if you haven't been using cond_scale, then perhaps it was never turned on
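For reference, this is the standard classifier-free guidance combination of conditional and unconditional (text-dropped) logits, written as a standalone sketch rather than meshgpt-pytorch's code. A cond_scale of 1 gives back the plain conditional logits, which is why CFG is effectively off unless cond_scale > 1. Training with cond_drop_prob > 0 is what gives the model a meaningful unconditional branch to compare against.

```python
import torch

def classifier_free_guidance(cond_logits: torch.Tensor, null_logits: torch.Tensor,
                             cond_scale: float) -> torch.Tensor:
    # cond_scale == 1 reduces to the conditional logits, i.e. no guidance
    if cond_scale == 1.:
        return cond_logits
    # push the prediction away from the unconditional logits, toward the conditional ones
    return null_logits + (cond_logits - null_logits) * cond_scale
```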
@MarcusLoppe oh crap, do i only append the start token for the fine transformer?? 🤦 yes you are correct, it is never conditioned then for the first set of fine tokens
thank you, this is a real issue then. i'll add cross attention to the fine transformer later today
edit: another cheap solution would be to project text embeddings, pool it, and just use that as the initial sos token
Awesome, although my 'fix' seems to be working as well.
By providing the text as tokens in the sequence, the fine decoder gets the text context, and it also helps create a stronger relationship with the mesh tokens and speeds up the training.
So the tokens it trains on look like: "chair XXXXXXXX" (where X is the mesh tokens).
The downside is that it needs a bigger vocab, which slows the training a bit, but the stronger relationship between the mesh tokens and the text seems to be working :)
I had some issues with providing the context to the fine decoder since the tensor changes shape, but you might be able to solve it.
However, I tried removing the gateloop and fine decoder so the main decoder is the last layer, but unfortunately it had the same issue.
@MarcusLoppe yup, your way is also legit
you have a bright future Marcus! finding this issue, the analysis, coming up with your own solution; none of that is trivial
Thank you very much! Although it took a while, I think I've learned a thing or two along the way.
I don't think the cross-attention will be enough; as per my last reply, I removed the fine decoder and gateloop and had the same issue.
If you think about multimodal generative models, they never start from token 0. For example, vision models have a prompt with a specific request from the user.
So the model has the first few tokens and some sort of goal or idea of what to generate, and then the cross-attention does its job and provides the additional context.
So the generative model has a more 'probabilistic path' from the start to get to the correct answer.
I think projecting the text embeddings might be the better way in this case.
@MarcusLoppe yup i went with the pooled text embedding summed to the sos token for now
let me know if that fixes things (or not)
this was really my fault for designing the initial architecture incorrectly
the sos token should be on the coarse transformer
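A minimal sketch of the "project, pool, and add to the sos token" idea, assuming per-token text embeddings of shape (batch, text_len, text_dim) with a boolean padding mask. The module and parameter names are hypothetical and this is not the actual meshgpt-pytorch implementation.

```python
import torch
from torch import nn

class TextConditionedSOS(nn.Module):
    """Hypothetical: learned sos token plus a projected, masked-mean-pooled text embedding."""
    def __init__(self, text_dim: int, dim: int):
        super().__init__()
        self.sos_token = nn.Parameter(torch.randn(dim))
        self.text_proj = nn.Linear(text_dim, dim)

    def forward(self, text_embeds: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
        # text_embeds: (b, n, text_dim); text_mask: (b, n) bool, True for real tokens
        projected = self.text_proj(text_embeds)
        mask = text_mask.unsqueeze(-1).float()
        # mean over real tokens only, so padding does not dilute the condition
        pooled = (projected * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.)
        # (b, dim) -> (b, 1, dim), ready to be packed in front of the face codes
        return (self.sos_token + pooled).unsqueeze(1)
```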
Awesome! I'll check it out.
However with the last x-transformers update I'm getting the error below.
The num_mem_kv doesn't seem to be picked up or trimmed by:
"attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)"
And dim_head in meshgpt isn't being passed correctly; it should be "attn_dim_head".
-> 1057 assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
1059 dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1061 self.dim = dim
AssertionError: unrecognized kwargs passed in dict_keys(['dim_head', 'num_mem_kv'])
@MarcusLoppe ah yes, those should have attn_ prepended, should be fixed in the latest version
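Why the prefix matters: x-transformers routes constructor kwargs to sub-modules by prefix (the groupby_prefix_and_trim call quoted above), so an unprefixed dim_head or num_mem_kv stays in the leftover kwargs and trips the assertion. Below is a simplified, hypothetical stand-in for that routing, for illustration only; it is not the library's source.

```python
def split_prefixed_kwargs(prefix: str, kwargs: dict) -> tuple[dict, dict]:
    """Illustrative stand-in for prefix-based kwarg routing (hypothetical helper)."""
    # keys with the prefix are routed to the sub-module, with the prefix stripped
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    # everything else stays behind; any unknown leftover would fail the assertion
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest

# 'attn_dim_head' reaches attention as 'dim_head'; a bare 'dim_head' is left over
attn_kwargs, rest = split_prefixed_kwargs('attn_', {'attn_dim_head': 64, 'attn_num_mem_kv': 4, 'dim_head': 64})
```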
Alright, here are some results.
Using the CLIP embedding model (larger distances in the embedding space) with a GPT-small-sized transformer:
I first trained on a small set of 350 models, each with 5 augmented copies. It only contains 39 unique labels, so there is some overlap between the texts.
The previous test just produced a blob of triangles; this time it output all tents and a blob.
I then took the same model and removed all augmentations, so it's x1 of each model with a unique text for each model.
This output somewhat better results, but it's still not following the text guidance.
I checked the logits: the first token generated was for a bench model, and the correct token was in 19th place with a value of 0.013.
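One way to run this kind of check is to look at where the correct token ranks among the first-step logits. A small sketch, assuming a 1-D logits vector for a single step and a known target id; the helper name is hypothetical.

```python
import torch

def target_rank_and_prob(logits: torch.Tensor, target: int) -> tuple[int, float]:
    """Return the 1-based rank of `target` among the logits and its softmax probability."""
    probs = logits.softmax(dim=-1)
    # rank = 1 + number of tokens strictly more probable than the target
    rank = int((probs > probs[target]).sum().item()) + 1
    return rank, probs[target].item()
```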
And as you can see, the loss at the start didn't show any improvement :/
As a sanity check, I trained a fresh model on the x1 set to 0.004 loss, but as you can see it didn't help. It might have made it worse.
I did the same test previously with my tokenized-text method and got all-perfect results on the x1 set (I did not test x5), which indicates that the cross-attention relationship isn't strong enough when there are no tokens.
Btw, I also tested just adding fake tokens at the start by increasing the codebook, e.g. using codebook_size + 1 (with eos at + 2), but that didn't change anything.
@MarcusLoppe ok, i threw in cross attention conditioning for fine transformer in 1.2.3
if that doesn't work, i'll just spend some time refactoring so the start token is in the coarse transformer. that would work for sure, as it is equivalent to your solution, except the text tokens do not undergo self attention
@MarcusLoppe thanks for running the experiments!
The loss improved much faster over the epochs, but there was a downside.
Before, it generated 100 tokens/s; now it's down to 80 tokens/s. Still, I think I much prefer this version, since it should cut down the training time.
Since inference got slower, so did each epoch: on a 2k dataset it went from 02:28 to 02:42, but I saw better loss improvements.
Unfortunately it did not work :(
However, something to note is that it did work before on the demo mesh dataset, which consists of 9 meshes.
@MarcusLoppe ah, thank you
ok, final try
will have to save this for late next week if it doesn't work
It worked better. Here is the result of training it on 39 models with unique labels; however, you can still see a spike at the start of the sequence, meaning the issue might not be resolved.
Using my method I got the results below; it manages to generate quite complex objects.
However, the start is still a bit weak. It would help if you could move the sos token into the coarse transformer; this would help the training time a lot, since it could reduce the vocab size from 32k to 2k :)
I've also experimented with using 3 tokens per triangle, and the autoencoder seems to work; however, it makes the training progression for the transformer slower. But considering that the VRAM requirement for training on 800-triangle meshes would go from 22GB to 9GB, and the generation time would halve, I think it's worth exploring.
However, I think the autoencoder could also benefit from getting the text embeddings. I tried passing them as the context in the linear attention layer, but since it requires the same shape as the quantized input, it won't accept them, nor do I think it would be very VRAM-friendly to duplicate the text embedding for every face.
Do you think there is an easy fix for this? I think it would reduce the codebook size a lot and help create codes with a closer relationship to the text, which would benefit the transformer a lot.
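On the VRAM concern: in PyTorch, broadcasting a per-mesh text embedding across faces with expand creates a view rather than a copy, so the duplication itself allocates nothing (the operations that consume it, such as the concatenation below, still allocate their outputs). A sketch assuming a (batch, text_dim) text embedding and (batch, faces, dim) face features; concatenate-then-project is just one hypothetical conditioning choice, not the autoencoder's actual design.

```python
import torch
from torch import nn

def condition_faces(face_feats: torch.Tensor, text_embed: torch.Tensor, proj: nn.Linear) -> torch.Tensor:
    # face_feats: (b, nf, d); text_embed: (b, d_text)
    b, nf, _ = face_feats.shape
    # broadcast view over the face axis (stride 0), no per-face copy of the embedding
    text_per_face = text_embed[:, None, :].expand(b, nf, text_embed.shape[-1])
    # concatenation materializes the combined tensor, then project back to d
    return proj(torch.cat((face_feats, text_per_face), dim=-1))
```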
@MarcusLoppe that is much better (an order of magnitude)! thank you for the experiments Marcus!
i know how to improve it (can add multiple sos to give the attention more surface area)
@MarcusLoppe i'll get back to this later this week
@MarcusLoppe oh, the sos token has already been moved to the coarse transformer in the latest commit. that's where the improvement you are seeing is coming from
Oh awesome. However, the loss got very low (0.006) for these results; on the bigger datasets the loss gets to about 0.01 and then needs something like 1000 epochs to reach a similar loss.
So some further improvements would be nice!
Any thoughts about the text-embedding-aware autoencoder?
@MarcusLoppe yup, we can try multiple sos tokens, then if that doesn't work, i'll build in the option to use prepended text embeddings (so like the solution you came up with, additional sos excised or pooled before fine transformer)
and yes, text embedding aware autoencoder is achievable! in fact, the original soundstream paper did this
Alright, I've tested the latest patch.
I tested using 1, 2, 4, 8, and 16 sos tokens, but I was unable to get any usable meshes from it.
I sanity-checked by reverting to the previous commit and was able to generate a valid mesh.
To generate the mesh, I tried setting cond_scale to 3 and turning off the kv_cache, but they just output a blob.
However, the loss is definitely smoothed out; as you can see, it no longer sticks up as it did before (I might have used too small a sample size, so ignore the loss values).
I think the issue might be that when they get averaged together they lose their meaning, or they get too complex to understand.
Also, the code below prevented me from testing with 1 sos token, since the sos token gets packed but never unpacked.
if exists(cache):
    cached_face_codes_len = cached_attended_face_codes.shape[-2]
    cached_face_codes_len_without_sos = cached_face_codes_len - 1
    need_call_first_transformer = face_codes_len > cached_face_codes_len_without_sos
else:
    # auto prepend sos token
    sos = repeat(self.sos_token, 'n d -> b n d', b = batch)
    face_codes, packed_sos_shape = pack([sos, face_codes], 'b * d')
    # if no kv cache, always call first transformer
    need_call_first_transformer = True

if need_call_first_transformer:
    if exists(self.coarse_gateloop_block):
        face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache)
    # ...

# with num_sos_tokens == 1 this branch is skipped, so the packed sos is never unpacked
if not exists(cache) and self.num_sos_tokens > 1:
    sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
    pooled_sos_token = reduce(sos_tokens, 'b n d -> b 1 d', 'mean')
    attended_face_codes = torch.cat((pooled_sos_token, attended_face_codes), dim = 1)
Previous commit (34f2806):
New commit:
@MarcusLoppe ok, let's go with your intuition and just grab the last sos token
Don't listen to me :) I think you are onto something; I don't think all the nuances of a text can be contained in a single token. As you can see, the loss is smoother and doesn't stick up, so it did something right.
Do you have a particular reason for using mean pooling?
Otherwise I'll do some testing with replacing it with an attention layer.
oh I actually kept the multiple sos tokens, but listened to your suggestion not to use mean pooling, and instead grab the last sos token to forward to fine transformer
was just reading a paper claiming that turning off CFG for earlier tokens leads to better results https://arxiv.org/html/2404.13040v1 should get this into the CFG repo at some point
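A minimal sketch of the "no CFG for the earliest tokens" idea, loosely following the description above. The cfg_start_step threshold is a hypothetical knob for illustration, not the linked paper's exact scheme and not an existing meshgpt-pytorch or CFG-repo parameter.

```python
import torch

def guided_logits(cond_logits: torch.Tensor, null_logits: torch.Tensor,
                  cond_scale: float, step: int, cfg_start_step: int) -> torch.Tensor:
    # hypothetical schedule: plain conditional logits for the first `cfg_start_step` tokens
    if step < cfg_start_step or cond_scale == 1.:
        return cond_logits
    # standard classifier-free guidance afterwards
    return null_logits + (cond_logits - null_logits) * cond_scale
```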
Oh alright, well I haven't checked out that patch yet, but the attention works really well!
I trained on 350 models, and it seems to have very strong text conditioning.
The only change I made was adding a simple linear layer! :) This might have been the last step of meshgpt! (not counting the last 10 "last steps")
I'll get back to you with the results of just using the last sos token.
Using the texts: 'bed', 'sofa', 'monitor', 'bench', 'chair', 'table'
I generated them 3 times; the first row is with temperature 0.0 and the others with 0.5.
As you can see, it has a very strong text relationship :)
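For anyone reproducing a grid like this: temperature 0.0 is commonly treated as greedy (argmax) decoding, while higher temperatures flatten the distribution before sampling. A generic sketch, not meshgpt-pytorch's sampler, which may apply additional filtering such as top-k.

```python
import torch

def sample_token(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # logits: (b, vocab). temperature <= 0 -> greedy argmax
    if temperature <= 0.:
        return logits.argmax(dim=-1)
    # scale logits, then sample one token per batch element
    probs = (logits / temperature).softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```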
@MarcusLoppe awesome! did you mean that you used an extra linear on the mean pooled sos tokens?
Hey, so I added a Linear layer and pooled the tokens as below:
self.attention_weights = nn.Linear(dim, 1, bias=False)
attention_scores = F.softmax(self.attention_weights(sos_tokens), dim=1)
pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)
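The three lines above, wrapped into a self-contained module for reuse. This is a sketch of the same softmax-weighted (attention) pooling; the class name is hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn

class AttentionPool(nn.Module):
    """Softmax-weighted pooling over the sos tokens (hypothetical wrapper)."""
    def __init__(self, dim: int):
        super().__init__()
        self.attention_weights = nn.Linear(dim, 1, bias=False)

    def forward(self, sos_tokens: torch.Tensor) -> torch.Tensor:
        # sos_tokens: (b, n_sos, dim) -> scores: (b, n_sos, 1), normalized over the sos axis
        attention_scores = F.softmax(self.attention_weights(sos_tokens), dim=1)
        # weighted sum over the sos axis -> (b, 1, dim)
        return torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)
```

With the linear weights at zero, every score is equal and this reduces exactly to mean pooling, so it generalizes the earlier mean-pool variant.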
It worked very well when I was dealing with 350 objects, but when I scaled up to 2k objects it didn't work as well.
I tried implementing more than one layer, but it got worse when I made it too complex.
Also, I'm not quite sure about taking the mean over the tokens, since that might promote similar tokens.
Do you have any ideas other than prepending the tokenized text? It really hurts performance if it needs to predict over 66k (50k+16k) tokens.
Here are some results: I took 355 samples, and for each sample I found other items in the dataset with the same type of label, e.g. "chair" and "tall chair". This gives a good idea of how well it keeps to the text condition.
As you can see, the attention variant had the best results using 2 or 4 tokens; however, your previous commit also had some good results.
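A sketch of how per-position Correct/Incorrect tallies like the ones below can be computed, assuming you already have one predicted token per position for each sample (e.g. greedy predictions given the ground-truth prefix). How the original test script produced its predictions is not stated, so the function is hypothetical.

```python
import torch

def per_position_accuracy(pred_codes: torch.Tensor, target_codes: torch.Tensor,
                          num_positions: int) -> list[tuple[int, int, int]]:
    # pred_codes, target_codes: (num_samples, seq_len) integer token ids
    correct = (pred_codes[:, :num_positions] == target_codes[:, :num_positions]).sum(dim=0)
    total = pred_codes.shape[0]
    # (1-based position, correct, incorrect) for each of the first num_positions tokens
    return [(i + 1, int(c), total - int(c)) for i, c in enumerate(correct)]
```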
Latest commit ("try not pooling all the sos tokens and instead using the last one"), using 4 sos tokens:
Token 1: Correct = 169, Incorrect = 186
Token 2: Correct = 234, Incorrect = 121
Token 3: Correct = 353, Incorrect = 2
Token 4: Correct = 354, Incorrect = 1
The commit before the last one ("move the concatenation of the sos token so it is always"):
Token 1: Correct = 258, Incorrect = 97
Token 2: Correct = 215, Incorrect = 140
Token 3: Correct = 354, Incorrect = 1
Token 4: Correct = 354, Incorrect = 1
All of the results below use the attention linear layer:
1 num_sos_tokens:
Token 1: Correct = 74, Incorrect = 281
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 352, Incorrect = 3
Token 4: Correct = 354, Incorrect = 1
2 num_sos_tokens:
Token 1: Correct = 237, Incorrect = 118
Token 2: Correct = 220, Incorrect = 135
Token 3: Correct = 353, Incorrect = 2
Token 4: Correct = 354, Incorrect = 1
4 num_sos_tokens:
Token 1: Correct = 237, Incorrect = 118
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 354, Incorrect = 1
8 num_sos_tokens:
Token 1: Correct = 126, Incorrect = 229
Token 2: Correct = 220, Incorrect = 135
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 354, Incorrect = 1
4 num_sos_tokens using mean:
pooled_sos_token = reduce(sos_tokens, 'b n d -> b 1 d', 'mean')
attention_scores = F.softmax(self.attention_weights(pooled_sos_token ), dim=1)
pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)
Token 1: Correct = 158, Incorrect = 197
Token 2: Correct = 214, Incorrect = 141
Token 3: Correct = 344, Incorrect = 11
Token 4: Correct = 350, Incorrect = 5
Token 5: Correct = 355, Incorrect = 0
you successfully applied the attention pooling from enformer!
thank you for the breakdown, going to default the number of sos tokens to 4
I think you misunderstood me: the results are bad, since in most cases the first tokens were not even remotely connected. When I used the exact same label to find valid codes, it failed very badly.
As you can see, from the 3rd token on it gets 0-2 incorrect, and the longer the sequence gets, the better the accuracy; after the 10th token it maybe gets 2 incorrect every 10 or so.
The takeaway from the tests is that accuracy at the 3rd token was 2000% better than at the 1st or 2nd token.
The issue is still alive, I'm afraid; it basically just guesses the first tokens.
I only trained on the first 36 tokens to speed up the testing, but I'm currently training on the full sequence so I can show you the generation results.
I'll post the results later on
@MarcusLoppe ah, you aren't referring to the number of sos tokens, but to the token number in the main sequence, my bad
try with a much larger number of sos tokens, say 16 or 32
I don't have the figures for those, but I tried 16 and got bad results.
As you can see, using 8 tokens had the worst results.
I'll fire up the test script and get you some hard numbers.
I know that setting up the sos tokens before the decoder and then inserting them after the cross-attention creates some sort of learnable relationship, and I assume the tokens change as the loss is optimized.
However, and I don't have any data to back this up, wouldn't it be better to have the tokens be a representation of the text embeddings?
If the sequence is 48 tokens, the majority of the loss comes after the first few tokens, and the sos tokens will be shaped/optimized to minimize that loss, meaning they adapt to work for the other 98% of the sequence.
Sort of like sacrificing the first wave of soldiers in a war to end up in a better position, so that no other soldiers need to die.
So is it possible to reshape the text embeddings (with any nn layer) to the model dim, insert them at the start of the sequence, and then add a special token?
ok, I'll try one more thing, but if that doesn't work, let's go for your text prefix solution
Little bit off topic, but I trained an autoencoder with a single quantizer plus a transformer and got good results. The progression was a little slower, but I got to about 0.03 loss with the transformer.
I didn't succeed in generating a mesh from 0 tokens, but when provided with 10 tokens it managed to generate a mesh :)
So that is a big win: halving the sequence length and reducing the VRAM requirement in training (800 faces) from 22 GB to 8 GB.
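Back-of-the-envelope for the sequence lengths, assuming (as the "3 tokens per triangle" experiment above implies) that each triangle contributes one code per quantizer for each of its three vertices. The helper is illustrative only.

```python
def mesh_sequence_length(num_faces: int, num_quantizers: int, vertices_per_face: int = 3) -> int:
    # one code per quantizer per vertex, three vertices per triangle
    return num_faces * vertices_per_face * num_quantizers

print(mesh_sequence_length(800, 2))  # 4800
print(mesh_sequence_length(800, 1))  # 2400
```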
Hi again.
Here are some failed attempts:
- I tried training with a higher loss weight on the first token, or even only applying the loss to the first token.
- I did some testing with text embeddings: I either prepended the text_embeds at the start, or used a linear layer (dim, dim * num_tokens) and then rearranged the output to (b, num_tokens, dim).
- Replaced the sos_tokens with text_embeds
- Tested using 32 tokens
I also wondered whether the decoder's cross-attention layer could handle it alone, but with just the decoder layer it couldn't handle any part of the sequence.
So what are your thoughts on the cross-attention? Do you think the sos token or the cross-attention can handle the cold start?
Since the issue is with the first 0-3 tokens, would it be beneficial to create some kind of embedding space that contains the first 3 tokens and is indexed by text embedding? That way, the text embedding provided by the user can be used to find the nearest neighbour.
It's not very novel, but it's a good way to at least kick-start the generation, although the issue might be resolved with scale later on.
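A minimal sketch of that lookup, assuming a bank of training-set text embeddings paired with each mesh's first k codes, and cosine similarity as the distance. The function is hypothetical; the returned codes would then be fed as the prompt to generation.

```python
import torch
import torch.nn.functional as F

def lookup_prefix(query_embed: torch.Tensor, bank_embeds: torch.Tensor,
                  bank_prefixes: torch.Tensor) -> tuple[torch.Tensor, float]:
    """Hypothetical index: return the stored first-k codes of the most similar training text."""
    # query_embed: (d,), bank_embeds: (n, d), bank_prefixes: (n, k)
    sims = F.cosine_similarity(bank_embeds, query_embed.unsqueeze(0), dim=-1)  # (n,)
    best = sims.argmax()
    return bank_prefixes[best], sims[best].item()
```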
The best result I got was with the commit below, though it may just be luck rather than consistent behaviour. The linear attention method had similar results, but without the slowdown from adding cross-attention to the fine decoder.
Training for many, many epochs using the commit "add cross attention based text conditioning for fine transformer too":
Token 1: Correct = 260, Incorrect = 95
Token 2: Correct = 319, Incorrect = 36
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 355, Incorrect = 0
Linear layer with 4 sos tokens
if not exists(cache):
    sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
    attention_scores = F.softmax(self.attention_weights(sos_tokens), dim=1)
    pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)
    attended_face_codes = torch.cat((pooled_sos_token, attended_face_codes), dim = 1)
Token 1: Correct = 237, Incorrect = 118
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 354, Incorrect = 1
Token 5: Correct = 355, Incorrect = 0
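The pooling in the "Linear layer with 4 sos tokens" snippet can be pulled out into a standalone module to check its shapes in isolation. The snippet doesn't show how `self.attention_weights` is defined; this sketch assumes a `Linear(dim, 1)` scoring layer:

```python
import torch
import torch.nn.functional as F
from torch import nn


class AttentionPool(nn.Module):
    """Pools (b, n, d) tokens into (b, 1, d) with learned softmax weights."""

    def __init__(self, dim: int):
        super().__init__()
        # assumption: one scalar score per token
        self.attention_weights = nn.Linear(dim, 1)

    def forward(self, sos_tokens: torch.Tensor) -> torch.Tensor:
        # (b, n, 1) scores, normalised over the token axis n
        scores = F.softmax(self.attention_weights(sos_tokens), dim=1)
        # weighted sum over tokens, keeping a length-1 sequence axis
        return torch.sum(scores * sos_tokens, dim=1, keepdim=True)
```

Because the weights sum to 1 over the tokens, the output is a convex combination of the sos tokens; if they are all identical, the pooled token equals that token.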
@MarcusLoppe thank you Marcus! will get a few solutions in soon
i also realized i wasn't caching the cross attention key / values correctly; will also get that fixed this morning
Awesome!
Outside of meshgpt, have you had success before training a decoder and letting it generate from a cold start with just an embedding? E.g. train on sequences of 6 tokens where the only input is an embedding used in the decoder's cross-attention.
It works fairly well when the dataset is small (<500). I don't think it's the model size, since it can remember 10k models if it's prompted with a few tokens.
Btw, let me know if I'm doing something wrong, but during my testing I just call forward_on_codes, take the logits and pick the token with argmax.
I'm not sure whether this disables the classifier-free guidance or not.
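In the common formulation of classifier-free guidance, the sampler runs the model twice, with and without the text condition, and combines the two sets of logits. With `cond_scale = 1` the combination reduces to the conditional logits, so a single conditioned forward pass followed by argmax behaves like sampling with no extra guidance. Whether `forward_on_codes` applies guidance internally isn't stated here; the sketch below is the generic formula, and `cfg_logits` is a hypothetical name:

```python
import torch


def cfg_logits(cond_logits: torch.Tensor, null_logits: torch.Tensor, cond_scale: float) -> torch.Tensor:
    # push the conditional prediction away from the unconditional one
    return null_logits + (cond_logits - null_logits) * cond_scale


cond = torch.tensor([1.0, 3.0, 2.0])  # logits with the text condition
null = torch.tensor([2.0, 1.0, 2.0])  # logits with the condition dropped
guided = cfg_logits(cond, null, cond_scale=3.0)
```

Larger `cond_scale` values exaggerate whatever the text changes about the prediction, which can help a weakly conditioned first token but can also push sampling off the data distribution.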
Hey again,
So I've noticed some strange behaviour with the cross attention num_mem_kv that might help you resolve the issue.
I had changed this value before without any noticeable effect.
However, using the commit with the fine-decoder cross-attention, I got the results below.
Setting cross_attn_num_mem_kv to 16 seems to hit some kind of sweet spot (maybe related to the dataset size).
This made it possible to generate a mesh from token 0, since it seems to hit the correct tokens. As you can see, the mesh is hardly smooth, but at least it's selecting the correct first token! I'm currently training to see whether using x5 augmentation of the same dataset yields better results, since it might be more robust.
I also tested setting the fine depth to 4 or 8, but both worsened performance; the same goes for increasing attn_num_mem_kv to 16.
I also tested using 16 cross_attn_num_mem_kv with all the other solutions you've posted, but there were no noticeable changes.
Commit: 5ef6cbf
8 cross_attn_num_mem_kv
Token 1: Correct = 6, Incorrect = 349
Token 2: Correct = 165, Incorrect = 190
Token 3: Correct = 320, Incorrect = 35
Token 4: Correct = 322, Incorrect = 33
Token 5: Correct = 341, Incorrect = 14
16 cross_attn_num_mem_kv
Token 1: Correct = 293, Incorrect = 62
Token 2: Correct = 331, Incorrect = 24
Token 3: Correct = 354, Incorrect = 1
Token 4: Correct = 354, Incorrect = 1
Token 5: Correct = 355, Incorrect = 0
16 cross_attn_num_mem_kv
8 fine_attn_depth
Token 1: Correct = 233, Incorrect = 122
Token 2: Correct = 189, Incorrect = 166
Token 3: Correct = 321, Incorrect = 34
Token 4: Correct = 313, Incorrect = 42
Token 5: Correct = 342, Incorrect = 13
32 cross_attn_num_mem_kv
Token 1: Correct = 4, Incorrect = 351
Token 2: Correct = 207, Incorrect = 148
Token 3: Correct = 345, Incorrect = 10
Token 4: Correct = 338, Incorrect = 17
Token 5: Correct = 349, Incorrect = 6
16 attn_num_mem_kv
16 cross_attn_num_mem_kv
Token 1: Correct = 5, Incorrect = 350
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 353, Incorrect = 2
Token 4: Correct = 355, Incorrect = 0
Token 5: Correct = 355, Incorrect = 0
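For context on the cross_attn_num_mem_kv results above: as far as I know, x-transformers implements memory key/values as learned key/value slots concatenated onto the keys and values being attended to. The single-head sketch below (hypothetical `CrossAttentionWithMemKV`, not the x-transformers code) shows the idea. One intuition is that the slots give queries a place to put attention mass when the text context is uninformative:

```python
import torch
import torch.nn.functional as F
from torch import nn


class CrossAttentionWithMemKV(nn.Module):
    """Single-head cross-attention with learned memory key/values (illustrative)."""

    def __init__(self, dim: int, context_dim: int, num_mem_kv: int):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(context_dim, dim, bias=False)
        self.to_v = nn.Linear(context_dim, dim, bias=False)
        # learned key/value slots that are always available to attend to
        self.mem_k = nn.Parameter(torch.randn(num_mem_kv, dim))
        self.mem_v = nn.Parameter(torch.randn(num_mem_kv, dim))

    def forward(self, x: torch.Tensor, context: torch.Tensor):
        b = x.shape[0]
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        # prepend the memory slots to the context keys/values for every batch element
        k = torch.cat((self.mem_k.expand(b, -1, -1), k), dim=1)
        v = torch.cat((self.mem_v.expand(b, -1, -1), v), dim=1)
        attn = F.softmax(q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5, dim=-1)
        return attn @ v, attn
```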
Hey, @lucidrains
I think I've figured something out. I made quite a lot of changes, but I had success by applying the following:
- Text embedding pooling with to_sos_text_cond
- sos_token (single parameter).
- Fine decoder cross-attention
Plus a few other tricks.
The training is also quite sensitive to the text masking and other factors; if the model becomes overtrained, the results are just blobs again.
When the conditions are good, the model always generates a complete shape; not always the one you want, but at least it's not a blob.
Btw, I also managed to train a model using 1 quantizer, which halved the inference time (duh :) ).
I wouldn't say this issue is resolved: with a dataset of 1k unique labels, generation steers towards the most 'average' mesh according to the text embeddings. You can see this averaging effect in the second image (cond scale sometimes helps, but setting it too high turns the mesh into a blob).
Hopefully this information helps you steer towards a final solution that can handle a large number of text labels.
Possible issue / accidental feature
I'm not sure if it's a problem, but since I add the sos_token before the main decoder and then add the pooled text embedding afterwards, 2 tokens carrying 'value' are added, and with the padding that comes to 12 extra tokens.
The first 6 extra tokens come from the autoregressive setup and the other 6 from the text embedding pooling, since it's added just before pad_to_length is called.
The result is that 1 token is replaced/lost through the right shift, because 2 tokens are added but only the sos_token is removed.
So the data between the decoder and the fine decoder is shifted right and ends up in a different order. This might not be an issue for the fine decoder, since the data is already reordered by the rearranging and the addition of grouped_codes, so the shape goes from (b, faces, dim)
to (b * (faces+1), (quantizers * vertices_per_face), dim).
But if you think of it linearly, ignoring how the network transforms the data, the output would be:
<pooled_text_embed> <mesh> <cut> <EOS> <extra tokens>
Instead of:
<mesh> <EOS> <cut> <extra tokens>
This is just a guess, but since the output now covers a longer window (12 tokens into the future instead of 6), it might help inference: during training the model outputs what it thinks comes after the EOS token. However, that output is cut off and doesn't affect the loss, so I'm not sure it matters. I also increased the padding to 18 tokens, but performance degraded.
I also tested replacing the pooled_text_embed with a Parameter of size dim, but it got worse results, so the text embedding does affect the output.
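A toy model of the bookkeeping described above, assuming the sequence is later truncated back to its original length. The token names are placeholders, not MeshGPT tensors, and `shift_with_extra_prefix` is a hypothetical helper:

```python
def shift_with_extra_prefix(mesh_tokens, sos="<sos>", pooled="<pooled_text_embed>"):
    # two tokens are added in front: the sos token and the pooled text embedding
    seq = [sos, pooled] + list(mesh_tokens)
    # only the sos token is removed afterwards
    seq = seq[1:]
    # truncating back to the original length drops the last mesh token
    return seq[:len(mesh_tokens)]


shifted = shift_with_extra_prefix(["m0", "m1", "m2", "m3"])
```

Every mesh position ends up one slot later than before, and the final mesh token falls off the end, which is the "1 token replaced/lost" effect.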
Multi-token prediction
I've been trying to understand how the transformer trains: at the end there is always 1 extra face (6 tokens), and then the sequence is cut off so that 5 tokens remain. I'm guessing this is done for the autoregression and the EOS token.
But I think it could provide an additional effect by extending the 'hidden' future tokens, and could be used for multi-token prediction.
I'm not sure where the masking is applied during training, but as a test I increased the number of codes that are cut off and set 'append_eos' to False to see if it can predict multiple tokens ahead.
Nothing as fancy as the Meta paper, just a weak proof of concept.
Here are some results after training for 15 epochs on the first 12 tokens of 2.8k meshes with 1000 labels:
1 tokens: 0.3990 loss (0.5574 loss without the text embedding pooling)
2 tokens: 0.112 loss
3 tokens: 0.24 loss
4 tokens: 0.1375 loss (woah!)
6 tokens: 0.1826 loss (18th epoch 0.104 loss)
if return_loss:
    assert seq_len > 0
    # cut number_of_tokens codes off the input; the full sequence is kept as labels
    codes, labels = codes[:, :-number_of_tokens], codes
    .......
    embed = embed[:, :(code_len + number_of_tokens)]
500 labels with 10 models per label, 2k codebook, number of quantizers: 2
1000 labels with 5 models per label, 2k codebook, number of quantizers: 2
100 labels with 25 models per label, 16k codebook, number of quantizers: 1
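The snippet above elides how the labels are wired up, so the following is only a generic sketch of multi-token targets, not the meshgpt-pytorch code: each input position gets the next `num_future` tokens as its targets, and a model would need one prediction per offset (for instance separate output heads). `multi_token_targets` is a hypothetical name:

```python
import torch


def multi_token_targets(codes: torch.Tensor, num_future: int):
    """Inputs plus num_future-ahead targets for every position (toy sketch)."""
    # drop the last num_future tokens so every input position has a full target window
    inputs = codes[:, :-num_future]
    # targets[:, t, j] is the token j + 1 steps after position t
    targets = codes[:, 1:].unfold(1, num_future, 1)
    return inputs, targets


codes = torch.arange(8).unsqueeze(0)  # (1, 8) toy token ids
inputs, targets = multi_token_targets(codes, num_future=3)
```

With `num_future = 1` this reduces to the usual next-token setup, with `targets[..., 0]` equal to the input shifted left by one.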
@MarcusLoppe thanks Marcus for the detailed report and glad to hear you found a solution!
i'll keep chipping away at it to strengthen conditioning
next up is to probably add adaptive layer/rms normalization to x-transformers
@MarcusLoppe you should see a slight speed bump on inference now, as i now cache the key / values correctly for cross attention
hope i didn't break anything!
Lovely :) I'll test the FiLM normalization method and let you know.
I tried replacing the PixelNorm with FiLM batch normalization in the ResNet, but had only mild success: the std for the first token and its relationship to the label decreased from 12 to 11.
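FiLM (feature-wise linear modulation) predicts a per-feature scale and shift from the conditioning vector. Below is a minimal sketch applied to token features, assuming a pooled text embedding as the condition. The `(1 + gamma)` parameterisation with zero initialisation, which makes the layer start as the identity, is a design choice of this sketch, not necessarily what meshgpt-pytorch or x-transformers do:

```python
import torch
from torch import nn


class FiLM(nn.Module):
    """Scale and shift features with values predicted from a condition vector."""

    def __init__(self, cond_dim: int, dim: int):
        super().__init__()
        self.to_gamma_beta = nn.Linear(cond_dim, dim * 2)
        # assumption: zero init so the layer starts as the identity (gamma = 0, beta = 0)
        nn.init.zeros_(self.to_gamma_beta.weight)
        nn.init.zeros_(self.to_gamma_beta.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (b, n, dim); cond: (b, cond_dim)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
        # broadcast the per-sample modulation over the sequence axis
        return x * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)


film = FiLM(cond_dim=8, dim=4)
x = torch.randn(2, 3, 4)
```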
However, the sos_token isn't quite how I implemented it; I've had more success just leaving the sos_token in without unpacking it.
I had the best success using the sos_token as in your commit "move the concatenation of the sos token so it is always conditioned b".
That commit has given far better results than:
- Unpacking multiple tokens + packing pooling
- Repacking single
- Unpacking multiple tokens and packing last token.
I tried explaining this before with my tests, but I might not have been clear enough.
Here is the implementation I've used:
https://github.com/MarcusLoppe/meshgpt-pytorch/blob/sos_token_test2/meshgpt_pytorch/meshgpt_pytorch.py
I'll give it a go :)
Btw, I can create a pull request for this: when using 1 quantizer, the rounding-down logic doesn't work in generation. Currently it's doing:
10 // 1 * 1 = 10 codes.
Instead of doing:
10 // 3 * 3 = 9 codes.
So the change below makes generation with 1 quantizer work.
From:
round_down_code_len = code_len // self.num_quantizers * self.num_quantizers
To:
round_down_code_len = code_len // self.num_vertices_per_face * self.num_vertices_per_face
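The rounding arithmetic above, written out. One assumption worth checking: if each face is `num_vertices_per_face * num_quantizers` codes (consistent with the `(quantizers * vertices_per_face)` grouping mentioned earlier in the thread), rounding down to that product keeps only whole faces for any quantizer count, and with 1 quantizer it gives the same result as the proposed change. `round_down` is a hypothetical helper:

```python
def round_down(code_len: int, multiple: int) -> int:
    # largest multiple of `multiple` that is <= code_len
    return code_len // multiple * multiple


num_vertices_per_face = 3
for num_quantizers in (1, 2):
    current = round_down(10, num_quantizers)
    proposed = round_down(10, num_vertices_per_face)
    whole_faces = round_down(10, num_quantizers * num_vertices_per_face)
    print(num_quantizers, current, proposed, whole_faces)
```

With 2 quantizers the proposed change keeps 9 codes (one and a half faces), while the whole-face rounding keeps 6.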
thanks for reporting the rounding down issue!
and yes, i can clean up the multiple sos tokens code if not needed. however, setting just 1 sos token should be equivalent to what you deem the best working commit
Just tested it on the first 12 tokens: using FiLM + mean pooling gives worse performance, plus it's giving me NaN loss.
However it might work better when applied to the full sequence, I'll get back to you.
Edit: Seems like the parameter usage in the FiLM layer improved it; it no longer gives NaN loss at the start.
Although I'm not sure it would help, since the issue might be the mean pooling. Say there are 1000s of text embeddings that are all unique: the cross-attention receives them in their original state, but the fine decoder gets the average of each embedding as an additional token.
After averaging, the embeddings are closer to each other than before, and some of them are no longer unique.
I think this is the reason why the same model is the 'default' for several mesh models.
@MarcusLoppe thanks for running it, made a few changes
yea, can definitely try attention pooling, which is a step up from mean pool
Okay, some updates.
I noticed that masked_mean was broken: classifier free guidance pads with id "0" instead of "-1", so the mean included the padded values and the pooled embedding for the same text changed from batch to batch.
This meant that the 'chair' embedding would take different values and effectively be random.
I created a pull request to add the padding id: "Added padding id option".
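A minimal sketch of a masked mean that ignores padding, assuming padded positions are filled with `pad_value` across the whole feature vector. `masked_mean_text` is a hypothetical name, not the function in the pull request:

```python
import torch


def masked_mean_text(text_embeds: torch.Tensor, pad_value: float = 0.0) -> torch.Tensor:
    # text_embeds: (b, n, d); a position counts as padding if every feature equals pad_value
    mask = (text_embeds != pad_value).any(dim=-1, keepdim=True).to(text_embeds.dtype)
    total = (text_embeds * mask).sum(dim=1)
    # clamp avoids division by zero for an all-padding row
    count = mask.sum(dim=1).clamp(min=1.0)
    return total / count


# one text with 2 real tokens, padded with zeros to length 4
embeds = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [0.0, 0.0], [0.0, 0.0]]])
```

With the right pad value, the same text padded to different lengths in different batches pools to the same vector; with the wrong one, the padding dilutes the mean and the result depends on the batch's padded length.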
After the padding issue was resolved, some texts were still matched far more often than others (e.g. "pallet" was matched 183 times out of 775). So I printed out the cosine similarities and noticed a pattern: 'chair' and 'pallet' had 0.99999 cosine similarity, and 'pallet' had 0.99999 similarity with many others. I was using CLIP, so I switched to T5/BGE, and the similarity was around 0.69, as expected.
I tried combining CLIP & T5 using model_types = ('t5', 'clip'),
however the same issue remained.
CLIP seems to work better than T5 & BGE on longer sentences and to capture more nuanced information.
Here are the cosine similarities, in percent (check out the last 2):
['pallet', 'chair']: BGE 58.8942, T5 67.7801, CLIP 99.9937
['a pallet', 'a chair']: BGE 71.565, T5 60.0675, CLIP 81.1785
['a pallet on floor', 'a chair on floor']: BGE 81.9622, T5 40.6449, CLIP 76.7676
['pallet', 'pallet on floor']: BGE 89.8585, T5 50.7333, CLIP 78.1393
['chair', 'pallet on floor']: BGE 60.5253, T5 37.083, CLIP 78.2081
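The check that exposed the CLIP collisions can be automated: compute pairwise cosine similarities (scaled by 100, as in the numbers above) for all labels and flag near-duplicates. This is a sketch with toy vectors standing in for real text embeddings; `near_duplicate_pairs` and the threshold of 99 are assumptions:

```python
import torch
import torch.nn.functional as F


def cosine_percent(a: torch.Tensor, b: torch.Tensor) -> float:
    # cosine similarity scaled by 100
    return F.cosine_similarity(a, b, dim=-1).item() * 100


def near_duplicate_pairs(labels, embeds: torch.Tensor, threshold: float = 99.0):
    # pairwise similarity matrix of the normalised embeddings, in percent
    normed = F.normalize(embeds, dim=-1)
    sims = normed @ normed.T * 100
    pairs = []
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            if sims[i, j] >= threshold:
                pairs.append((labels[i], labels[j], round(sims[i, j].item(), 4)))
    return pairs


labels = ["pallet", "chair", "lamp"]
toy_embeds = torch.tensor([[1.0, 0.0], [1.0, 0.001], [0.0, 1.0]])
pairs = near_duplicate_pairs(labels, toy_embeds)
```

Any pair flagged this way is effectively the same conditioning signal, so the model cannot tell those labels apart.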
After making these fixes, I tried again and had much better success :)
I tested FiLM and got mixed results. I can't say confidently which is best: the 1.5k-token tests say it's best with FiLM on, but the training without FiLM is better.
The results are very good, but there are some issues: in the test using 5 models per label, I have a very hard time generating 3 distinct rows of the same furniture labels.
It does follow the label, but it seems to ignore the other models (even with cond_scale and a high temperature of 0.8).
I trained on a dataset of 775 labels with 5 examples each (3.8k meshes). The first tests were trained on only 60 tokens total; the latest one was trained on the full 1500-token sequences. I tested using 10 examples per label, but that training run needs more time before I get an accurate picture of its performance.
To calculate the accuracy, I took all the unique labels and created a list of all the models that share each label. Using this list, I checked whether any of these models has the same first 3 tokens as the generated sequence; if so, I count it as "correct".
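A sketch of that accuracy check, assuming the dataset is a list of (label, codes) pairs and one generated sequence per label. The per-token counts below treat "Token k" as "the first k tokens match some model with that label", which is one plausible reading of the tables; `prefix_accuracy` and `per_token_counts` are hypothetical helpers:

```python
from collections import defaultdict


def prefix_accuracy(dataset, generated, num_tokens=3):
    """dataset: list of (label, codes); generated: dict label -> generated codes."""
    # collect the first num_tokens codes of every training model, grouped by label
    prefixes_by_label = defaultdict(set)
    for label, codes in dataset:
        prefixes_by_label[label].add(tuple(codes[:num_tokens]))
    labels = sorted(prefixes_by_label)
    # a generation is correct if its prefix matches any model with the same label
    correct = sum(
        tuple(generated[label][:num_tokens]) in prefixes_by_label[label]
        for label in labels
    )
    return correct / len(labels), correct, len(labels)


def per_token_counts(dataset, generated, num_tokens=3):
    # (k, correct, incorrect) when matching only the first k tokens
    counts = []
    for k in range(1, num_tokens + 1):
        _, correct, total = prefix_accuracy(dataset, generated, k)
        counts.append((k, correct, total - correct))
    return counts


data = [("chair", [1, 2, 3, 4]), ("chair", [1, 5, 6, 7]), ("bed", [9, 9, 9, 9])]
gen = {"chair": [1, 5, 6, 0], "bed": [9, 9, 8, 1]}
```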
Using FiLM
T5, trained on 60 tokens:
Generation accuracy:
Accuracy: 0.9987096774193548 all_correct: 774 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 774, Incorrect = 1
Token 3: Correct = 774, Incorrect = 1
Forward accuracy:
Accuracy: 0.27225806451612905 all_correct: 211 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
BGE, trained on 60 tokens:
Generation accuracy:
Accuracy: 1.0 all_correct: 775 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
Forward accuracy:
Accuracy: 0.36 all_correct: 279 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
BGE trained on 1500 tokens:
Generation accuracy:
Accuracy: 0.984516129032258 all_correct: 763 len(test_dataset): 775
Token 1: Correct = 763, Incorrect = 12
Token 2: Correct = 763, Incorrect = 12
Token 3: Correct = 763, Incorrect = 12
Generation - Super loose (cosine similarity with text 90%)
Accuracy: 0.9922580645161291 all_correct: 769 len(test_dataset): 775
Token 1: Correct = 770, Incorrect = 5
Token 2: Correct = 769, Incorrect = 6
Token 3: Correct = 769, Incorrect = 6
Forward accuracy:
Accuracy: 0.28774193548387095 all_correct: 223 len(test_dataset): 775
Token 1: Correct = 760, Incorrect = 15
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
Without FiLM
T5, trained on 60 tokens:
Generation accuracy:
Accuracy: 0.9987096774193548 all_correct: 774 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 774, Incorrect = 1
Token 3: Correct = 774, Incorrect = 1
Forward accuracy:
Accuracy: 0.34580645161290324 all_correct: 268 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
BGE, trained on 60 tokens:
Generation accuracy:
Accuracy: 1.0 all_correct: 775 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
Forward accuracy:
Accuracy: 0.33419354838709675 all_correct: 259 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
BGE trained on 1500 tokens:
Generation accuracy:
Accuracy: 0.9780645161290322 all_correct: 758 len(test_dataset): 775
Token 1: Correct = 758, Incorrect = 17
Token 2: Correct = 758, Incorrect = 17
Token 3: Correct = 758, Incorrect = 17
Generation - Super loose (cosine similarity with text 90%)
Accuracy: 0.9883870967741936 all_correct: 766 len(test_dataset): 775
Token 1: Correct = 766, Incorrect = 9
Token 2: Correct = 766, Incorrect = 9
Token 3: Correct = 766, Incorrect = 9
Forward accuracy:
Accuracy: 0.28774193548387095 all_correct: 223 len(test_dataset): 775
Token 1: Correct = 746, Incorrect = 29
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0
775 labels with 10x examples each (7.7k meshes):
BGE trained on 1500 tokens:
Generation accuracy:
Accuracy: 0.9858064516129033 all_correct: 764 len(test_dataset): 775
Token 1: Correct = 764, Incorrect = 11
Token 2: Correct = 764, Incorrect = 11
Token 3: Correct = 764, Incorrect = 11
Generation - Super loose (cosine similarity with text 90%)
Accuracy: 0.9896774193548387 all_correct: 767 len(test_dataset): 775
Token 1: Correct = 767, Incorrect = 8
Token 2: Correct = 767, Incorrect = 8
Token 3: Correct = 767, Incorrect = 8
Token 4: Correct = 767, Incorrect = 8
Renders:
Mesh files: https://file.io/OxWCcTYpUbxN
Without FiLM
['bed', 'sofa', 'monitor', 'bench', 'chair', 'table', 'console table console', 'object on a stick', 'knife', 'billboard', 'concrete structure', 'rod', 'stick with a handle', 'shark fin', 'wooden railing', 'zigzag chair straight chair side chair', 'building', 'building with a few windows', 'wooden bench', 'screen with nothing on it', 'bird with a beak', 'apple', 'trash can with a lid', 'blocky robot', 'crystal on a base', 'sign with arrows on it', 'three different colocubes', 'crowbar', 'sheet of paper', 'metal rod', 'computer screen computer display screen crt screen', 'octagon', 'chair on a floor', 'platform bed', 'dark object', 'u shaped object', 'bed with a headboard', 'three blocks', 'brush', 'whale tail', 'staircase', 'lamp', 'broken cylinder', 'night stand', 'traffic cone', 'drill', 'four geometric shapes', 'trash can', 'rocket ship', 'traffic cone on a base', 'pyramid and a square', 'robot that is standing up', 'computer monitor', 'pixelated object', 'coffee cup', 'chaise longue chaise daybed chair', 'three rocks', 'desk table-tennis table ping-pong table pingpong table', 'letter k', 'box with a ribbon', 'dice with dots on it']
With FiLM
['bed', 'sofa', 'monitor', 'bench', 'chair', 'table', 'block animal', 'car', 'book', 'minecraft character in a shirt', 'pot', 'hanging sign', 'minecraft character wearing a shirt', 'secretary writing table escritoire secretaire', 'apple with a stem', 'two colorful blocks', 'line', 'security camera', 'pan', 'armchair', 'picket fence', 'palm tree', 'the letter t', 'bottle', 'park bench bench', 'sign', 'wooden baseball bat', 'wooden table', 'container with a lid', 'picnic table with two benches', 'geometric tree', 'operating table', 'traffic light', 'checkechair', 'dumpster', 'group of cubes', 'chair on a floor', 'two rectangular objects', 'box sitting', 'blocky chicken with an beak', 'octagon', 'couch with arm rests', 'metal object', 'top hat', 'diamond shaped object', 'two shelves', 'two rocks', 'apple', 'cartoon snail carrying a box', 'bench', 'cylinder and a cube', 'minecraft dolphin', 'question mark', 'tree made out of blocks', 'display video display', 'wooden stool', 'box with stripes on it', 'metal rod', 'tree stump', 'three dimensional object', 'side table table']
@MarcusLoppe awesome, i think after adding adaptive layernorms to x-transformers, we can close this issue
that will surely be enough
@lucidrains
A quick question: I'm about to train a larger model that will require many text embeddings, and I'm a bit worried that cross-attention with the text embedding might 'take up' too much of the model.
I was thinking of using a dataset of 150k models, so the text guidance will need to be extremely good to represent those meshes.
My idea is this: cross-attention to a text embedding has a many-to-many relationship with the tokens, whereas if the model instead attended to the sequence itself, the relationship would be more or less one-to-many.
So if we used the learnable tokens and compressed the text embedding through them, so that they represent the text within the sequence, the model could attend to the sequence itself, which would be more efficient and simpler.
I was wondering if something like the below would work?
The multiple sos tokens are only kept for the main decoder, not for the rest of the network, and there is no Q-Former architecture that takes the text embeddings and encodes their information into them.
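A sketch of the "compress the text through learnable tokens" idea, written as a perceiver-resampler-style block: a fixed set of learned queries cross-attends to the (variable-length) text embeddings and returns a fixed number of tokens that could then be prepended to the face-token sequence as sos tokens. This is not the Q-Former or any meshgpt-pytorch code; `QueryResampler` and its sizes are assumptions:

```python
import torch
from torch import nn


class QueryResampler(nn.Module):
    """Compress variable-length text embeddings into num_queries tokens (sketch)."""

    def __init__(self, dim: int, text_dim: int, num_queries: int, heads: int = 4):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
        self.text_proj = nn.Linear(text_dim, dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, text_embeds: torch.Tensor, text_pad_mask=None) -> torch.Tensor:
        # text_embeds: (b, n, text_dim); text_pad_mask: (b, n), True where padded
        b = text_embeds.shape[0]
        context = self.text_proj(text_embeds)
        q = self.queries.unsqueeze(0).expand(b, -1, -1)
        # learned queries cross-attend to the text; output length is always num_queries
        out, _ = self.attn(q, context, context, key_padding_mask=text_pad_mask)
        return self.norm(out + q)


resampler = QueryResampler(dim=64, text_dim=32, num_queries=8)
text = torch.randn(2, 12, 32)
pad_mask = torch.zeros(2, 12, dtype=torch.bool)
pad_mask[:, 10:] = True  # last two text positions are padding
```

Because padding is excluded through `key_padding_mask`, this avoids the masked-mean pitfall, and the learned queries can keep more detail than a single pooled vector.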
@MarcusLoppe oh, i don't even know what the q-former architecture is haha
i'll have to read it later this week, but it sounds like just a cross attention based recompression, similar to perceiver resampler
just got the adaptive layernorm conditioning into the repo! i think we can safely close this issue
@MarcusLoppe we can chat about q-former in the discussions tab
@MarcusLoppe oh yes, the qformer architecture is in vogue. bytedance recently used it for their vqvae to compress images even further.
will explore this in the coming month for sure!
Would it be possible to explore this sooner? :) Or maybe provide a hint on how to do this?
A student at a university in the States has offered to help using 16 H100s for 2 weeks!
He'll be granted the compute shortly (today) and we'll start the autoencoder training.
However, I'm still a bit unsure whether cross-attention is the best way, since I had some trouble using it for 10k labels.
I don't think meshgpt will get this attention again so any advice would be helpful! :)
@MarcusLoppe is he/she a phd or ms student? if so, you and him/her should be able to work together and implement it, could even make for a short paper
or i can take a look at it, but probably not for another month or two
I think he's a PhD student; he applied for the compute a while ago and was granted it. It's not for a thesis or graded paper, but perhaps a technical report.
I'm happy and interested to implement it myself, but as with many SOTA things, it might be above my head.
As far as I understand it, the image patches (image / 32) are processed through an encoder, and the outputs are quantized into a codebook. Then these codes can be used by the transformer as token indices, or maybe just the encoded dim output from the quantizer.
Am I on the right track?
@MarcusLoppe ok, if he's a phd student, you two should definitely be able to work it out from the code already available
@MarcusLoppe i'm not sure without spending a day reading the paper, but it looks to me they are simply using appended "query" tokens, which is similar to memory / register tokens in the literature. they simply concat it to the sequence and then attend to everything, and slice it out. it is similar to the sos tokens we've been playing around with, except it isn't autoregressive
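The "concat query tokens, attend to everything, slice them out" pattern described above can be sketched with standard PyTorch layers. Unlike the sos tokens, there is no causal mask here, so every position sees every other one. `WithQueryTokens` and its sizes are hypothetical:

```python
import torch
from torch import nn


class WithQueryTokens(nn.Module):
    """Append learned query tokens, self-attend over everything, slice them back out."""

    def __init__(self, dim: int, num_queries: int, heads: int = 4, depth: int = 2):
        super().__init__()
        self.num_queries = num_queries
        self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
        layer = nn.TransformerEncoderLayer(dim, heads, dim_feedforward=dim * 4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (b, n, dim)
        b = tokens.shape[0]
        q = self.queries.unsqueeze(0).expand(b, -1, -1)
        # non-causal self-attention over the concatenation of inputs and queries
        x = self.encoder(torch.cat((tokens, q), dim=1))
        # the last num_queries positions hold the compressed summary
        return x[:, -self.num_queries:]


compress = WithQueryTokens(dim=64, num_queries=32)
summary = compress(torch.randn(2, 100, 64))
```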
@MarcusLoppe ask your collaborator! he should know if he is in the field
I've read a bit further; you might be right, and I'm just not understanding your terminology.
My understanding is that they train an autoencoder (tokenizer) and use only 32 tokens to represent the image.
They extract patches from the image and represent them as token(s), and then during generation they can mask certain tokens and generate new, novel images (?).
I'm not quite sure it's applicable to this project. I played around a bit with using the sos tokens, but got worse results.
I used an encoder layer that cross-attends to the text embeddings to output the dim for the sos tokens.
I also tried encoding [sos_tokens, text_embedding] and returning the first 32 outputs of the encoder ([:32]).
I was thinking that maybe the issue isn't that the text embeddings are too weak, but that the cross-attention messes things up a bit.
Since the cross-attention to the text embeddings isn't one-to-one, it might be unstable, because the model learns that one text embedding has 100s of different 'correct solutions'.
I think it would be more stable if the text were within the token sequence, but the sos tokens haven't quite worked out.
@MarcusLoppe you should def chat with your collaborator (instead of me) since you'll be training the next model together
he will probably be more up-to-date with mesh research too, as he is following it full time