Skip to content

Commit 8e08a04

Browse files
committed
basic NS
1 parent e7052af commit 8e08a04

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

modula/atom.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ def regularize(self, w, strength):
4545
def print_submodules(self):
4646
print(f"Linear module of shape {(self.out_features, self.in_features)} and mass {self.mass}.")
4747

48+
class NSLinear(Linear):
49+
def __init__(self, out_features, in_features, mass=1, steps=10):
50+
super().__init__(out_features, in_features, mass)
51+
self.steps = steps
52+
53+
@torch.no_grad()
54+
def normalize(self, w, target_norm):
55+
weight = w[0]
56+
weight.div_(weight.norm())
57+
for _ in range(self.steps):
58+
# TODO: speed this up
59+
weight.data = 1.5 * weight - 0.5 * weight @ weight.t() @ weight
60+
weight.mul_(target_norm)
4861

4962
class Conv2D(Module):
5063
def __init__(self, out_channels, in_channels, kernel_size=3, stride=1, padding=1, mass=1):

0 commit comments

Comments
 (0)