1 | | -import copy |
2 | | -import torch |
3 | | - |
4 | | -from modula.vector import Vector |
| 1 | +import jax |
5 | 2 |
6 | 3 | class Module: |
7 | 4 | def __init__(self): |
8 | | - self.mass = None |
9 | | - self.sensitivity = None |
10 | | - self.length = None |
11 | 5 | self.children = [] |
12 | | - |
| 6 | + |
| 7 | + self.atoms = None # number of atoms: int |
| 8 | + self.bonds = None # number of bonds: int |
| 9 | + self.smooth = None # is this module smooth?: bool |
| 10 | + self.sensitivity = None # input Lipschitz estimate: float > 0 |
| 11 | + self.mass = None # proportional contribution of module toward feature learning of any supermodule: float >= 0 |
| 12 | + |
| 13 | + def __str__(self): |
| 14 | + string = self.__class__.__name__ |
| 15 | + string += f"\n...consists of {self.atoms} atoms and {self.bonds} bonds" |
| 16 | + string += f"\n...{'smooth' if self.smooth else 'non-smooth'}" |
| 17 | + string += f"\n...input sensitivity is {self.sensitivity}" |
| 18 | + string += f"\n...contributes proportion {self.mass} to feature learning of any supermodule" |
| 19 | + return string |
| 20 | + |
| 21 | + def tare(self, absolute=1.0, relative=None): |
| 22 | + if relative is None: |
| 23 | + self.tare(relative = absolute / self.mass) |
| 24 | + else: |
| 25 | + self.mass *= relative |
| 26 | + for m in self.children: |
| 27 | + m.tare(relative = relative) |
| 28 | + |
| 29 | + def jit(self): |
| 30 | + self.forward = jax.jit(self.forward) |
| 31 | + self.backward = jax.jit(self.backward) |
| 32 | + self.project = jax.jit(self.project) |
| 33 | + self.dualize = jax.jit(self.dualize) |
| 34 | + |
13 | 35 | def forward(self, x, w): |
| 36 | + # Input and weight list --> output and list of internal activations. |
14 | 37 | raise NotImplementedError |
15 | 38 |
16 | | - def initialize(self, device, dtype): |
| 39 | + def backward(self, w, acts, grad_output): |
| 40 | + # Weight list, activation list and output gradient --> weight gradient list and input gradient. |
17 | 41 | raise NotImplementedError |
18 | 42 |
19 | | - def normalize(self, w, target_norm): |
| 43 | + def initialize(self, key): |
| 44 | + # Return a weight list. |
20 | 45 | raise NotImplementedError |
21 | 46 |
22 | | - def regularize(self, w, strength): |
| 47 | + def project(self, w): |
| 48 | + # Return a weight list. |
23 | 49 | raise NotImplementedError |
24 | 50 |
25 | | - def tare(self, absolute=1, relative=None): |
26 | | - if relative is not None: |
27 | | - self.mass *= relative |
28 | | - for child in self.children: |
29 | | - child.tare(relative = relative) |
30 | | - else: |
31 | | - self.tare(relative = absolute / self.mass) |
32 | | - |
33 | | - def print_submodules(self): |
34 | | - for child in self.children: |
35 | | - child.print_submodules() |
36 | | - |
37 | | - def __str__(self): |
38 | | - return f"Module of mass {self.mass} and sensitivity {self.sensitivity}." |
39 | | - |
40 | | - def __call__(self, x, w): |
41 | | - return self.forward(x, w) |
| 51 | + def dualize(self, grad_w, target_norm): |
| 52 | + # Weight gradient list and number --> normalized weight gradient list |
| 53 | + raise NotImplementedError |
42 | 54 |
43 | 55 | def __matmul__(self, other): |
44 | | - if isinstance(other, tuple): other = TupleModule(other) |
45 | 56 | return CompositeModule(self, other) |
46 | 57 |
47 | | - def __rmatmul__(self, other): |
48 | | - if isinstance(other, tuple): other = TupleModule(other) |
49 | | - return other @ self |
50 | | - |
51 | 58 | def __add__(self, other): |
52 | | - return Add() @ (self, other) |
| 59 | + return Add() @ TupleModule((self, other)) |
53 | 60 |
54 | | - def __mul__(self, other): |
55 | | - assert other != 0, "cannot multiply a module by zero" |
56 | | - return self @ Mul(other) |
| 61 | + def __rmul__(self, scalar): |
| 62 | + return Mul(scalar) @ self |
57 | 63 |
58 | | - def __rmul__(self, other): |
59 | | - assert other != 0, "cannot multiply a module by zero" |
60 | | - return Mul(other) @ self |
| 64 | + def __call__(self, x, w): |
| 65 | + return self.forward(x, w) |
61 | 66 |
62 | | - def __truediv__(self, other): |
63 | | - assert other != 0, "cannot divide a module by zero" |
64 | | - return self * (1/other) |
| 67 | +class Atom(Module): |
| 68 | + def __init__(self): |
| 69 | + super().__init__() |
| 70 | + self.atoms = 1 |
| 71 | + self.bonds = 0 |
65 | 72 |
66 | | - def __pow__(self, other): |
67 | | - assert other >= 0 and other % 1 == 0, "nonnegative integer powers only" |
68 | | - if other > 0: |
69 | | - return copy.deepcopy(self) @ self ** (other - 1) |
70 | | - else: |
71 | | - return Mul(1.0) |
| 73 | +class Bond(Module): |
| 74 | + def __init__(self): |
| 75 | + super().__init__() |
| 76 | + self.atoms = 0 |
| 77 | + self.bonds = 1 |
| 78 | + self.mass = 0 |
72 | 79 |
| 80 | + def initialize(self, key): |
| 81 | + return [] |
| 82 | + |
| 83 | + def project(self, w): |
| 84 | + return [] |
| 85 | + |
| 86 | + def dualize(self, grad_w, target_norm=1.0): |
| 87 | + return [] |
73 | 88 |
74 | 89 | class CompositeModule(Module): |
75 | 90 | def __init__(self, m1, m0): |
76 | 91 | super().__init__() |
77 | 92 | self.children = (m0, m1) |
78 | | - self.length = m0.length + m1.length |
79 | | - self.mass = m0.mass + m1.mass |
80 | | - self.sensitivity = m1.sensitivity * m0.sensitivity |
81 | | - |
| 93 | + |
| 94 | + self.atoms = m0.atoms + m1.atoms |
| 95 | + self.bonds = m0.bonds + m1.bonds |
| 96 | + self.smooth = m0.smooth and m1.smooth |
| 97 | + self.mass = m0.mass + m1.mass |
| 98 | + self.sensitivity = m0.sensitivity * m1.sensitivity |
| 99 | + |
82 | 100 | def forward(self, x, w): |
83 | 101 | m0, m1 = self.children |
84 | | - w0 = w[:m0.length] |
85 | | - w1 = w[m0.length:] |
86 | | - return m1.forward(m0.forward(x, w0), w1) |
| 102 | + w0 = w[:m0.atoms] |
| 103 | + w1 = w[m0.atoms:] |
| 104 | + x0, activations0 = m0.forward(x, w0) |
| 105 | + x1, activations1 = m1.forward(x0, w1) |
| 106 | + return x1, activations0 + activations1 |
87 | 107 |
88 | | - def initialize(self, device, dtype=torch.float32): |
| 108 | + def initialize(self, key): |
89 | 109 | m0, m1 = self.children |
90 | | - return m0.initialize(device, dtype=dtype) & m1.initialize(device, dtype=dtype) |
| 110 | + key, subkey = jax.random.split(key) |
| 111 | + return m0.initialize(key) + m1.initialize(subkey) |
91 | 112 |
92 | | - def normalize(self, w, target_norm=1): |
93 | | - if self.mass > 0: |
94 | | - m0, m1 = self.children |
95 | | - w0 = Vector(w[:m0.length]) |
96 | | - w1 = Vector(w[m0.length:]) |
97 | | - m0.normalize(w0, target_norm=m0.mass / self.mass * target_norm / m1.sensitivity) |
98 | | - m1.normalize(w1, target_norm=m1.mass / self.mass * target_norm) |
99 | | - else: |
100 | | - w *= 0 |
| 113 | + def project(self, w): |
| 114 | + m0, m1 = self.children |
| 115 | + w0 = w[:m0.atoms] |
| 116 | + w1 = w[m0.atoms:] |
| 117 | + return m0.project(w0) + m1.project(w1) |
| 118 | + |
| 119 | + def backward(self, w, acts, grad_output): |
| 120 | + m0, m1 = self.children |
| 121 | + w0 = w[:m0.atoms] |
| 122 | + w1 = w[m0.atoms:] |
| 123 | + acts0 = acts[:m0.atoms+m0.bonds] |
| 124 | + acts1 = acts[m0.atoms+m0.bonds:] |
| 125 | + |
| 126 | + grad_w1, grad_input1 = m1.backward(w1, acts1, grad_output) |
| 127 | + grad_w0, grad_input0 = m0.backward(w0, acts0, grad_input1) |
101 | 128 |
102 | | - def regularize(self, w, strength): |
| 129 | + return grad_w0 + grad_w1, grad_input0 |
| 130 | + |
| 131 | + def dualize(self, grad_w, target_norm=1.0): |
103 | 132 | if self.mass > 0: |
104 | 133 | m0, m1 = self.children |
105 | | - w0 = Vector(w[:m0.length]) |
106 | | - w1 = Vector(w[m0.length:]) |
107 | | - m0.regularize(w0, strength=m0.mass / self.mass * strength / m1.sensitivity) |
108 | | - m1.regularize(w1, strength=m1.mass / self.mass * strength) |
109 | | - |
| 134 | + grad_w0, grad_w1 = grad_w[:m0.atoms], grad_w[m0.atoms:] |
| 135 | + d_w0 = m0.dualize(grad_w0, target_norm = target_norm * m0.mass / self.mass / m1.sensitivity) |
| 136 | + d_w1 = m1.dualize(grad_w1, target_norm = target_norm * m1.mass / self.mass) |
| 137 | + d_w = d_w0 + d_w1 |
| 138 | + else: |
| 139 | + d_w = [0 * grad_weight for grad_weight in grad_w] |
| 140 | + return d_w |
110 | 141 |
111 | 142 | class TupleModule(Module): |
112 | | - def __init__(self, tuple_of_modules): |
| 143 | + def __init__(self, python_tuple_of_modules): |
113 | 144 | super().__init__() |
114 | | - self.children = tuple_of_modules |
115 | | - self.length = sum(child.length for child in self.children) |
116 | | - self.mass = sum(child.mass for child in self.children) |
117 | | - self.sensitivity = sum(child.sensitivity for child in self.children) |
118 | | - |
| 145 | + self.children = python_tuple_of_modules |
| 146 | + self.atoms = sum(m.atoms for m in self.children) |
| 147 | + self.bonds = sum(m.bonds for m in self.children) |
| 148 | + self.smooth = all(m.smooth for m in self.children) |
| 149 | + self.mass = sum(m.mass for m in self.children) |
| 150 | + self.sensitivity = sum(m.sensitivity for m in self.children) |
| 151 | + |
119 | 152 | def forward(self, x, w): |
120 | | - output = [] |
121 | | - for child in self.children: |
122 | | - w_child = w[:child.length] |
123 | | - output.append(child.forward(x, w_child)) |
124 | | - w = w[child.length:] |
125 | | - return output |
126 | | - |
127 | | - def initialize(self, device, dtype=torch.float32): |
128 | | - vector = Vector() |
129 | | - for child in self.children: |
130 | | - vector &= child.initialize(device, dtype=dtype) |
131 | | - return vector |
132 | | - |
133 | | - def normalize(self, w, target_norm=1): |
| 153 | + output_list = [] |
| 154 | + act_list = [] |
| 155 | + for m in self.children: |
| 156 | + output, act = m.forward(x, w[:m.atoms]) |
| 157 | + output_list.append(output) |
| 158 | + act_list += act |
| 159 | + w = w[m.atoms:] |
| 160 | + return output_list, act_list |
| 161 | + |
| 162 | + def backward(self, w, acts, grad_output): |
| 163 | + grad_w = [] |
| 164 | + grad_input = 0 |
| 165 | + for m, grad_output_m in zip(self.children, grad_output): |
| 166 | + grad_w_m, grad_input_m = m.backward(w[:m.atoms], acts[:m.atoms+m.bonds], grad_output_m) |
| 167 | + grad_w += grad_w_m |
| 168 | + grad_input += grad_input_m |
| 169 | + w = w[m.atoms:] |
| 170 | + acts = acts[m.atoms+m.bonds:] |
| 171 | + return grad_w, grad_input |
| 172 | + |
| 173 | + def initialize(self, key): |
| 174 | + w = [] |
| 175 | + for m in self.children: |
| 176 | + key, subkey = jax.random.split(key) |
| 177 | + w += m.initialize(subkey) |
| 178 | + return w |
| 179 | + |
| 180 | + def project(self, w): |
| 181 | + projected_w = [] |
| 182 | + for m in self.children: |
| 183 | + projected_w_m = m.project(w[:m.atoms]) |
| 184 | + projected_w += projected_w_m |
| 185 | + w = w[m.atoms:] |
| 186 | + return projected_w |
| 187 | + |
| 188 | + def dualize(self, grad_w, target_norm=1.0): |
134 | 189 | if self.mass > 0: |
135 | | - for child in self.children: |
136 | | - w_child = Vector(w[:child.length]) |
137 | | - child.normalize(w_child, target_norm=child.mass / self.mass * target_norm) |
138 | | - w = Vector(w[child.length:]) |
| 190 | + d_w = [] |
| 191 | + for m in self.children: |
| 192 | + grad_w_m = grad_w[:m.atoms] |
| 193 | + d_w_m = m.dualize(grad_w_m, target_norm = target_norm * m.mass / self.mass) |
| 194 | + d_w += d_w_m |
| 195 | + grad_w = grad_w[m.atoms:] |
139 | 196 | else: |
140 | | - w *= 0 |
141 | | - |
142 | | - def regularize(self, w, strength): |
143 | | - if self.mass > 0: |
144 | | - for child in self.children: |
145 | | - w_child = Vector(w[:child.length]) |
146 | | - child.regularize(w_child, strength=child.mass / self.mass * strength) |
147 | | - w = Vector(w[child.length:]) |
148 | | - |
| 197 | + d_w = [0 * grad_weight for grad_weight in grad_w] |
| 198 | + return d_w |
149 | 199 |
150 | | -class Mul(Module): |
151 | | - def __init__(self, alpha): |
| 200 | +class Identity(Bond): |
| 201 | + def __init__(self): |
152 | 202 | super().__init__() |
153 | | - self.mass = 0 |
154 | | - self.sensitivity = abs(alpha) |
155 | | - self.length = 0 |
156 | | - self.initialize = lambda device, dtype : Vector() |
157 | | - self.normalize = lambda w, target_norm : None |
158 | | - self.regularize = lambda w, strength : None |
159 | | - self.alpha = alpha |
160 | | - |
161 | | - def forward(self, x, _): |
162 | | - if isinstance(x, list): |
163 | | - return [self.forward(xi, _) for xi in x] |
164 | | - else: |
165 | | - return self.alpha * x |
| 203 | + self.smooth = True |
| 204 | + self.sensitivity = 1 |
166 | 205 |
| 206 | + def forward(self, x, w): |
| 207 | + return x, [None] |
| 208 | + |
| 209 | + def backward(self, w, acts, grad_output): |
| 210 | + return [], grad_output |
167 | 211 |
168 | | -class Add(Module): |
| 212 | +class Add(Bond): |
169 | 213 | def __init__(self): |
170 | 214 | super().__init__() |
171 | | - self.mass = 0 |
| 215 | + self.smooth = True |
172 | 216 | self.sensitivity = 1 |
173 | | - self.length = 0 |
174 | | - self.initialize = lambda device, dtype : Vector() |
175 | | - self.normalize = lambda w, target_norm : None |
176 | | - self.regularize = lambda w, strength : None |
177 | | - self.forward = lambda x, w : sum(x) |
| 217 | + |
| 218 | + def forward(self, x, w): |
| 219 | + return sum(x), [None] |
| 220 | + |
| 221 | + def backward(self, w, acts, grad_output): |
| 222 | + return [], (grad_output, grad_output) |
| 223 | + |
| 224 | +class Mul(Bond): |
| 225 | + def __init__(self, scalar): |
| 226 | + super().__init__() |
| 227 | + self.smooth = True |
| 228 | + self.scalar = scalar |
| 229 | + self.sensitivity = abs(scalar) |
| 230 | + |
| 231 | + def forward(self, x, w): |
| 232 | + return x * self.scalar, [None] |
| 233 | + |
| 234 | + def backward(self, w, acts, grad_output): |
| 235 | + return [], grad_output * self.scalar |
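
Since bonds carry no weights, the bookkeeping above can be smoke-tested with an empty weight list. A minimal sketch, not part of the diff, assuming the classes above are in scope:

```python
import jax.numpy as jnp

# 2 * (Identity + Identity) is built entirely from bonds, so w = [].
net = 2.0 * (Identity() + Identity())
print(net)  # 0 atoms, 4 bonds, smooth, sensitivity 4.0, mass 0

x = jnp.ones(3)
y, acts = net(x, [])                                  # y == 4 * x
grad_w, grad_x = net.backward([], acts, jnp.ones(3))  # grad_x == 4 * ones(3)
d_w = net.dualize(grad_w)                             # [] since there are no atoms
```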
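The atoms themselves live elsewhere in the package, but the contract they must satisfy is visible above: `forward` returns `(output, activations)`, `backward` consumes weights, activations and an output gradient, and `initialize`/`project`/`dualize` operate on flat weight lists. Here is a hypothetical vector-only `Linear` atom, purely to illustrate that contract; the shapes, the `1/sqrt(fanin)` initialization and the Frobenius-norm `dualize` are illustrative assumptions, not the library's actual choices:

```python
import jax
import jax.numpy as jnp

class Linear(Atom):
    """Hypothetical dense atom acting on vectors, for illustration only."""
    def __init__(self, fanout, fanin, mass=1.0):
        super().__init__()
        self.fanout, self.fanin = fanout, fanin
        self.smooth = True
        self.mass = mass
        self.sensitivity = 1.0

    def forward(self, x, w):
        # The single activation entry this atom stores is its own input.
        return w[0] @ x, [x]

    def backward(self, w, acts, grad_output):
        x = acts[0]
        return [jnp.outer(grad_output, x)], w[0].T @ grad_output

    def initialize(self, key):
        return [jax.random.normal(key, (self.fanout, self.fanin)) / jnp.sqrt(self.fanin)]

    def project(self, w):
        return w

    def dualize(self, grad_w, target_norm=1.0):
        # Crude Frobenius normalization; a real atom would use a sharper norm.
        return [target_norm * grad_w[0] / (jnp.linalg.norm(grad_w[0]) + 1e-12)]

# Compose 4 -> 8 -> 2 and run one forward/backward/dualize cycle.
mlp = Linear(2, 8) @ Linear(8, 4)
mlp.tare()                                    # masses now sum to 1
w = mlp.initialize(jax.random.PRNGKey(0))

x = jnp.ones(4)
y, acts = mlp(x, w)
grad_w, grad_x = mlp.backward(w, acts, jnp.ones_like(y))
update = mlp.dualize(grad_w, target_norm=0.1)
w = [wi - ui for wi, ui in zip(w, update)]    # one norm-controlled update step
```

Note how `CompositeModule.dualize` splits the budget: the first layer receives `target_norm * mass0 / mass / sensitivity1` and the second `target_norm * mass1 / mass`, mirroring the old `normalize` logic.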