mup/__init__.py: 2 changes (1 addition, 1 deletion)

@@ -4,4 +4,4 @@
 from mup.infshape import *
 from mup.init import *
 from mup.layer import *
-from mup.optim import *
+from mup.optim import MuSGD, MuAdam, MuAdamW, process_param_groups
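With the explicit re-export above, the optimizer entry points stay reachable at the package top level; only names that mup.optim pulls in for its own use stop leaking through the old star import. A minimal sketch of what downstream imports look like after this change (assuming mup.optim defines no __all__):

# Still works: these four names are re-exported explicitly by mup/__init__.py.
from mup import MuSGD, MuAdam, MuAdamW, process_param_groups

# Names that mup.optim only imports internally (e.g. torch.optim.Adam, the new
# module logger) no longer spill into the top-level mup namespace; if needed,
# import them from the submodule directly, or better, from torch.optim itself.
from mup.optim import Adam  # illustrative only; normally use torch.optim.Adam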
mup/optim.py: 5 changes (5 additions, 0 deletions)

@@ -23,6 +23,8 @@ def MuOptimizer(params, **kwargs):
 
 from torch.optim import SGD, Adam, AdamW
 
+import logging
+logger = logging.getLogger(__name__)
 
 def process_param_groups(params, **kwargs):
     param_groups = list(params)
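Because the warning added below goes through the standard logging module, downstream code can tune or silence it with the usual logging controls. A small sketch, assuming the module logger resolves to the name "mup.optim" (it is created with logging.getLogger(__name__)):

import logging

# Hide the weight-decay warning from mup.optim while keeping other warnings:
logging.getLogger("mup.optim").setLevel(logging.ERROR)

# Or, in an application that has no logging configured yet, make sure
# warnings are actually emitted to stderr:
logging.basicConfig(level=logging.WARNING)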
@@ -51,6 +53,9 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
     An instance of `impl` with refined parameter groups, each of which has the correctly
     scaled learning rate according to mup.
     '''
+    if impl == Adam and kwargs.get('weight_decay', False):
+        logger.warning('MuAdam does not scale weight decay correctly. Use MuAdamW instead.')
+
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
         # For every existing param group, we split into several new groups
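The check fires whenever MuAdam is called with the plain Adam implementation and a truthy weight_decay, steering users toward MuAdamW's decoupled weight decay. A rough usage sketch; the toy MLP, widths, and hyperparameters are illustrative and not part of the PR, and it assumes the usual mup setup via set_base_shapes:

import logging
import torch.nn as nn
from mup import set_base_shapes, MuAdam, MuAdamW

logging.basicConfig(level=logging.WARNING)

def make_mlp(width):
    # Toy model; a real mup model would also use MuReadout for the output layer.
    return nn.Sequential(nn.Linear(32, width), nn.ReLU(), nn.Linear(width, 32))

# mup optimizers read each parameter's infshape, so base shapes must be set first.
model = make_mlp(width=256)
set_base_shapes(model, make_mlp(width=64))

# Triggers the new warning: Adam's coupled weight decay is not rescaled by mup.
opt = MuAdam(model.parameters(), lr=1e-3, weight_decay=0.01)
# -> WARNING: MuAdam does not scale weight decay correctly. Use MuAdamW instead.

# Recommended path: MuAdamW, whose decoupled weight decay behaves as intended.
opt = MuAdamW(model.parameters(), lr=1e-3, weight_decay=0.01)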