Hi, I tried to modify the MLD-VAE to use a normal transformer instead of the skipTransformer, but the loss stops dropping at around total loss = 0.17. However, I found that it gives results similar to yours if I set batch_first = True. Could you please explain a little about why that is?
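
In case it helps, here is a minimal sketch of what I suspect might be going on. It assumes plain PyTorch `nn.TransformerEncoderLayer`, not the actual MLD code, and the shapes and the helper names (`make_layer`, `leaks_across_batch`) are made up for illustration. If the input is laid out as (batch, seq, feature) but the layer keeps the default `batch_first=False`, the first axis is treated as the sequence, so attention would mix different samples:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, not taken from the MLD config.
batch, seq_len, dim = 4, 10, 16
x = torch.randn(batch, seq_len, dim)  # laid out as (batch, seq, feature)

def make_layer(batch_first):
    # dropout=0.0 and eval() keep the forward pass deterministic.
    layer = nn.TransformerEncoderLayer(
        d_model=dim, nhead=4, dim_feedforward=32,
        dropout=0.0, batch_first=batch_first,
    )
    return layer.eval()

def leaks_across_batch(layer):
    # Perturb only sample 1, then check whether the output for sample 0 changes.
    x_perturbed = x.clone()
    x_perturbed[1] += 1.0
    with torch.no_grad():
        out_a = layer(x)
        out_b = layer(x_perturbed)
    return not torch.allclose(out_a[0], out_b[0], atol=1e-5)

# Default layout: axis 0 is read as the sequence, so samples attend to each other.
leak_default = leaks_across_batch(make_layer(batch_first=False))
# batch_first=True: axis 0 is the batch, so each sample is processed independently.
leak_batch_first = leaks_across_batch(make_layer(batch_first=True))

print("batch_first=False mixes samples:", leak_default)
print("batch_first=True mixes samples:", leak_batch_first)
```

This only shows how the flag changes which axis is attended over; I am not sure it is the actual reason for the plateau in my case.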
Best,
Zen