Update distill.py to include device agnostic code for distill_mlp head and distillation_token #324
Open
vivekh2000 wants to merge 1 commit into lucidrains:main from
Conversation
…ead and `distillation_token`

Since in your code the `distillation_token` and the `distill_mlp` head are defined in the `DistillWrapper` class, sending a model instance of the `DistillableViT` class to the GPU does not send them to the GPU. While training a model with this code I got a device mismatch error, and it was hard to figure out its source. The `distillation_token` and `distill_mlp` turned out to be the culprits, since they are defined not in the model class but in the `DistillWrapper` class. I have therefore suggested the following changes: when training a model on GPU, the training code should set `device = "cuda" if torch.cuda.is_available() else "cpu"`, or the same can be incorporated into the constructor of the `DistillWrapper` class.
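For reference, a minimal sketch of the idea (not the exact diff in this PR): the wrapper-owned pieces follow the device of the incoming image tensor, so moving only the student model to the GPU no longer causes a mismatch. The class and attribute names mirror `distill.py`, but the constructor arguments and the forward body below are simplified placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillWrapper(nn.Module):
    def __init__(self, *, student, teacher, dim = 256, num_classes = 10):
        super().__init__()
        self.student = student
        self.teacher = teacher
        # defined on the wrapper, not on the DistillableViT instance,
        # so moving only the student to the GPU never moves these
        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
        self.distill_mlp = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, img, labels):
        device = img.device                                 # follow the input's device
        distill_token = self.distillation_token.to(device)  # no-op if already there
        self.distill_mlp.to(device)                         # nn.Module.to() moves in place
        embed = self.student(img)                           # assumed (batch, dim) embedding
        logits = self.distill_mlp(embed + distill_token.view(-1))  # broadcast over the batch
        return F.cross_entropy(logits, labels)
```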
Since in your code the `distillation_token` and `distill_mlp` head are defined in the `DistillWrapper` class, sending the model instance of the `DistillableViT` class to GPU does not send the `distillation_token` and `distill_mlp` head to GPU. Therefore, while training a model using this code, I got a device mismatch error, which made it hard to figure out the source of the error. Finally, the `distillation_token` and `distill_mlp` turned out to be the culprits, as they are not defined in the model class but in the `DistillWrapper` class, which is a wrapper around the loss function. Therefore, I have suggested the following changes when training a model on GPU: the training code should set `device = "cuda" if torch.cuda.is_available() else "cpu"`, or the same can be incorporated into the constructor of the `DistillWrapper` class.
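As a sketch of the second option mentioned above (resolving the device in the `DistillWrapper` constructor itself): the argument names and defaults below are illustrative, not the library's actual signature.

```python
import torch
import torch.nn as nn

class DistillWrapper(nn.Module):
    def __init__(self, *, student, teacher, dim = 256, num_classes = 10, device = None):
        super().__init__()
        # resolve the device once in the constructor; callers may still override it
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.student = student
        self.teacher = teacher
        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim, device = self.device))
        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        ).to(self.device)

# training-side usage: pick the device with the same expression and pass it in
device = "cuda" if torch.cuda.is_available() else "cpu"
wrapper = DistillWrapper(student = nn.Identity(), teacher = nn.Identity(), device = device)
```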