lucidrains/x-transformers

Adding memmask to ContinuousTransformerWrapper

Closed this issue · 3 comments

@lucidrains could you add memmask to ContinuousTransformerWrapper please?
Thanks

Correct me if I'm wrong, but the only difference between ContinuousTransformerWrapper and TransformerWrapper is that one uses an nn.Linear layer to project the inputs while the other uses an nn.Embedding layer. Those should be the only differences, right?
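Here is a minimal PyTorch sketch of the input-side difference described above. The sizes (`dim`, `num_tokens`, `dim_in`) are illustrative assumptions and are not taken from the x-transformers source. Both paths end up as a `(batch, seq, dim)` tensor, which is why everything after the input layer can be shared.

```python
import torch
from torch import nn

dim = 16          # model width shared by both paths (illustrative value)
num_tokens = 100  # vocabulary size for the discrete case (illustrative value)
dim_in = 3        # feature size for the continuous case (illustrative value)

# TransformerWrapper-style input: integer token ids -> learned embeddings
token_emb = nn.Embedding(num_tokens, dim)
# ContinuousTransformerWrapper-style input: real-valued features -> linear projection
project_in = nn.Linear(dim_in, dim)

tokens = torch.randint(0, num_tokens, (2, 8))  # (batch, seq) integer ids
features = torch.randn(2, 8, dim_in)           # (batch, seq, dim_in) floats

x_discrete = token_emb(tokens)
x_continuous = project_in(features)

# Both are (batch, seq, dim), so the attention layers that follow can be identical
print(x_discrete.shape, x_continuous.shape)
```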

In which case, can't we delete ContinuousTransformerWrapper and simply make:

  • dim_in
  • dim_out
  • num_tokens

all optional?

Add some asserts like:

  • assert (dim_in is None) ^ (num_tokens is None), "either project input or embed tokens, not both"
  • assert (dim_out is None) ^ (num_tokens is None), "either project output or predict logits, not both"

Something like that. Then the wrapper can decide, when it is constructed, whether it behaves like ContinuousTransformerWrapper or like the normal TransformerWrapper.
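To make the proposal concrete, here is a rough sketch of a single wrapper with optional `num_tokens`, `dim_in` and `dim_out`. `UnifiedWrapper` is a hypothetical name, not an x-transformers class, and a small MLP stands in for the real attention layers. Comparing `is None` results with `^` avoids the `TypeError` that `None ^ int` would raise.

```python
import torch
from torch import nn

class UnifiedWrapper(nn.Module):
    """Hypothetical wrapper: embeds tokens or projects features, chosen at construction."""

    def __init__(self, dim, num_tokens=None, dim_in=None, dim_out=None):
        super().__init__()
        # Exactly one input mode and exactly one output mode must be chosen
        assert (dim_in is None) ^ (num_tokens is None), "either project input or embed tokens, not both"
        assert (dim_out is None) ^ (num_tokens is None), "either project output or predict logits, not both"

        # Input: embedding for token ids, linear projection for continuous features
        if num_tokens is not None:
            self.embed = nn.Embedding(num_tokens, dim)
        else:
            self.embed = nn.Linear(dim_in, dim)

        # Stand-in for the attention layers, shared by both modes (assumption)
        self.layers = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

        # Output: a linear head in both modes; only its width differs
        out_features = num_tokens if num_tokens is not None else dim_out
        self.head = nn.Linear(dim, out_features)

    def forward(self, x):
        return self.head(self.layers(self.embed(x)))

discrete = UnifiedWrapper(dim=16, num_tokens=100)
continuous = UnifiedWrapper(dim=16, dim_in=3, dim_out=3)

logits = discrete(torch.randint(0, 100, (2, 8)))  # (2, 8, 100) logits
preds = continuous(torch.randn(2, 8, 3))          # (2, 8, 3) continuous predictions
```

The XOR also rejects the case where neither argument is given, so every instance has to commit to exactly one input mode and one output mode.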

Also, my understanding is that at the output, dim_out and num_tokens do the same thing: either way, an nn.Linear layer is used. So you probably don't even need dim_out.
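The output side can be sketched the same way, assuming a hypothetical `make_head` helper. In both modes the head is an `nn.Linear`; what differs is its width (vocabulary size vs. feature size) and how its output is interpreted afterwards.

```python
import torch
from torch import nn

dim = 16

def make_head(dim, out_features):
    # Hypothetical helper: the head is a plain linear layer in either mode
    return nn.Linear(dim, out_features)

hidden = torch.randn(2, 8, dim)   # stand-in for the output of the shared layers
logits_head = make_head(dim, 100) # out_features = num_tokens (vocabulary size)
regress_head = make_head(dim, 3)  # out_features = dim_out (feature size)

logits = logits_head(hidden)      # (2, 8, 100): unnormalized scores over tokens
values = regress_head(hidden)     # (2, 8, 3): predicted continuous features

# Only the discrete output is read as a probability distribution
probs = logits.softmax(dim=-1)
```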

Anyway, this will stop me from creating issues like "can you add feature X from TransformerWrapper to ContinuousTransformerWrapper" :)

@pfeatherstone oops, yeah, I added it

hmm, maybe at a later date; for now, let's keep them separate

@pfeatherstone that is the only difference within this wrapper

other differences exist in the loss and sampling of its autoregressive wrapper
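A hedged sketch of how loss and sampling typically diverge between the two autoregressive cases, assuming cross-entropy for tokens and mean-squared error for continuous values. The actual x-transformers autoregressive wrappers may use different defaults or options, and all function names here are hypothetical.

```python
import torch
import torch.nn.functional as F

def discrete_loss(logits, tokens):
    # logits: (batch, seq, num_tokens); position i predicts token i + 1
    pred = logits[:, :-1]   # drop the last position (no target for it)
    target = tokens[:, 1:]  # next-token targets
    # cross_entropy expects the class dimension second: (batch, num_tokens, seq - 1)
    return F.cross_entropy(pred.transpose(1, 2), target)

def continuous_loss(preds, features):
    # preds: (batch, seq, dim_out); regress the next feature vector directly
    return F.mse_loss(preds[:, :-1], features[:, 1:])

def discrete_next(logits, temperature=1.0):
    # Sample a token id from the distribution at the last position
    probs = (logits[:, -1] / temperature).softmax(dim=-1)
    return torch.multinomial(probs, 1)  # (batch, 1) token ids

def continuous_next(preds):
    # No distribution to sample from: the prediction itself becomes the next input
    return preds[:, -1:]  # (batch, 1, dim_out)
```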