
Commit 91d52ad

pytorch --> jax
1 parent b3a89a2 commit 91d52ad

8 files changed: +292, -478 lines changed


modula/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+from . import abstract
 from . import atom
 from . import bond
 from . import compound
+from . import error

modula/abstract.py

Lines changed: 185 additions & 128 deletions
@@ -1,177 +1,234 @@
-import copy
-import torch
-
-from modula.vector import Vector
+import jax
 
 class Module:
     def __init__(self):
-        self.mass = None
-        self.sensitivity = None
-        self.length = None
         self.children = []
-
+
+        self.atoms = None        # number of atoms: int
+        self.bonds = None        # number of bonds: int
+        self.smooth = None       # is this module smooth?: bool
+        self.sensitivity = None  # input Lipschitz estimate: float > 0
+        self.mass = None         # proportional contribution of module toward feature learning of any supermodule: float >= 0
+
+    def __str__(self):
+        string = self.__class__.__name__
+        string += f"\n...consists of {self.atoms} atoms and {self.bonds} bonds"
+        string += f"\n...{'smooth' if self.smooth else 'non-smooth'}"
+        string += f"\n...input sensitivity is {self.sensitivity}"
+        string += f"\n...contributes proportion {self.mass} to feature learning of any supermodule"
+        return string
+
+    def tare(self, absolute=1.0, relative=None):
+        if relative is None:
+            self.tare(relative = absolute / self.mass)
+        else:
+            self.mass *= relative
+            for m in self.children:
+                m.tare(relative = relative)
+
+    def jit(self):
+        self.forward = jax.jit(self.forward)
+        self.backward = jax.jit(self.backward)
+        self.project = jax.jit(self.project)
+        self.dualize = jax.jit(self.dualize)
+
     def forward(self, x, w):
+        # Input and weight list --> output and list of internal activations.
         raise NotImplementedError
 
-    def initialize(self, device, dtype):
+    def backward(self, w, grad_output):
+        # Weight list and output gradient --> weight gradient list and input gradient.
         raise NotImplementedError
 
-    def normalize(self, w, target_norm):
+    def initialize(self, key):
+        # Return a weight list.
         raise NotImplementedError
 
-    def regularize(self, w, strength):
+    def project(self, w):
+        # Return a weight list.
         raise NotImplementedError
 
-    def tare(self, absolute=1, relative=None):
-        if relative is not None:
-            self.mass *= relative
-            for child in self.children:
-                child.tare(relative = relative)
-        else:
-            self.tare(relative = absolute / self.mass)
-
-    def print_submodules(self):
-        for child in self.children:
-            child.print_submodules()
-
-    def __str__(self):
-        return f"Module of mass {self.mass} and sensitivity {self.sensitivity}."
-
-    def __call__(self, x, w):
-        return self.forward(x, w)
+    def dualize(self, grad_w, target_norm):
+        # Weight gradient list and number --> normalized weight gradient list
+        raise NotImplementedError
 
     def __matmul__(self, other):
-        if isinstance(other, tuple): other = TupleModule(other)
         return CompositeModule(self, other)
 
-    def __rmatmul__(self, other):
-        if isinstance(other, tuple): other = TupleModule(other)
-        return other @ self
-
     def __add__(self, other):
-        return Add() @ (self, other)
+        return Add() @ TupleModule((self, other))
 
-    def __mul__(self, other):
-        assert other != 0, "cannot multiply a module by zero"
-        return self @ Mul(other)
+    def __rmul__(self, scalar):
+        return Mul(scalar) @ self
 
-    def __rmul__(self, other):
-        assert other != 0, "cannot multiply a module by zero"
-        return Mul(other) @ self
+    def __call__(self, x, w):
+        return self.forward(x, w)
 
-    def __truediv__(self, other):
-        assert other != 0, "cannot divide a module by zero"
-        return self * (1/other)
+class Atom(Module):
+    def __init__(self):
+        super().__init__()
+        self.atoms = 1
+        self.bonds = 0
 
-    def __pow__(self, other):
-        assert other >= 0 and other % 1 == 0, "nonnegative integer powers only"
-        if other > 0:
-            return copy.deepcopy(self) @ self ** (other - 1)
-        else:
-            return Mul(1.0)
+class Bond(Module):
+    def __init__(self):
+        super().__init__()
+        self.atoms = 0
+        self.bonds = 1
+        self.mass = 0
 
+    def initialize(self, key):
+        return []
+
+    def project(self, w):
+        return []
+
+    def dualize(self, grad_w, target_norm=1.0):
+        return []
 
 class CompositeModule(Module):
     def __init__(self, m1, m0):
         super().__init__()
         self.children = (m0, m1)
-        self.length = m0.length + m1.length
-        self.mass = m0.mass + m1.mass
-        self.sensitivity = m1.sensitivity * m0.sensitivity
-
+
+        self.atoms = m0.atoms + m1.atoms
+        self.bonds = m0.bonds + m1.bonds
+        self.smooth = m0.smooth and m1.smooth
+        self.mass = m0.mass + m1.mass
+        self.sensitivity = m0.sensitivity * m1.sensitivity
+
     def forward(self, x, w):
         m0, m1 = self.children
-        w0 = w[:m0.length]
-        w1 = w[m0.length:]
-        return m1.forward(m0.forward(x, w0), w1)
+        w0 = w[:m0.atoms]
+        w1 = w[m0.atoms:]
+        x0, activations0 = m0.forward(x, w0)
+        x1, activations1 = m1.forward(x0, w1)
+        return x1, activations0 + activations1
 
-    def initialize(self, device, dtype=torch.float32):
+    def initialize(self, key):
         m0, m1 = self.children
-        return m0.initialize(device, dtype=dtype) & m1.initialize(device, dtype=dtype)
+        key, subkey = jax.random.split(key)
+        return m0.initialize(key) + m1.initialize(subkey)
 
-    def normalize(self, w, target_norm=1):
-        if self.mass > 0:
-            m0, m1 = self.children
-            w0 = Vector(w[:m0.length])
-            w1 = Vector(w[m0.length:])
-            m0.normalize(w0, target_norm=m0.mass / self.mass * target_norm / m1.sensitivity)
-            m1.normalize(w1, target_norm=m1.mass / self.mass * target_norm)
-        else:
-            w *= 0
+    def project(self, w):
+        m0, m1 = self.children
+        w0 = w[:m0.atoms]
+        w1 = w[m0.atoms:]
+        return m0.project(w0) + m1.project(w1)
+
+    def backward(self, w, acts, grad_output):
+        m0, m1 = self.children
+        w0 = w[:m0.atoms]
+        w1 = w[m0.atoms:]
+        acts0 = acts[:m0.atoms+m0.bonds]
+        acts1 = acts[m0.atoms+m0.bonds:]
+
+        grad_w1, grad_input1 = m1.backward(w1, acts1, grad_output)
+        grad_w0, grad_input0 = m0.backward(w0, acts0, grad_input1)
 
-    def regularize(self, w, strength):
+        return grad_w0 + grad_w1, grad_input0
+
+    def dualize(self, grad_w, target_norm=1.0):
         if self.mass > 0:
             m0, m1 = self.children
-            w0 = Vector(w[:m0.length])
-            w1 = Vector(w[m0.length:])
-            m0.regularize(w0, strength=m0.mass / self.mass * strength / m1.sensitivity)
-            m1.regularize(w1, strength=m1.mass / self.mass * strength)
-
+            grad_w0, grad_w1 = grad_w[:m0.atoms], grad_w[m0.atoms:]
+            d_w0 = m0.dualize(grad_w0, target_norm = target_norm * m0.mass / self.mass / m1.sensitivity)
+            d_w1 = m1.dualize(grad_w1, target_norm = target_norm * m1.mass / self.mass)
+            d_w = d_w0 + d_w1
+        else:
+            d_w = [0 * grad_weight for grad_weight in grad_w]
+        return d_w
 
 class TupleModule(Module):
-    def __init__(self, tuple_of_modules):
+    def __init__(self, python_tuple_of_modules):
         super().__init__()
-        self.children = tuple_of_modules
-        self.length = sum(child.length for child in self.children)
-        self.mass = sum(child.mass for child in self.children)
-        self.sensitivity = sum(child.sensitivity for child in self.children)
-
+        self.children = python_tuple_of_modules
+        self.atoms = sum(m.atoms for m in self.children)
+        self.bonds = sum(m.bonds for m in self.children)
+        self.smooth = all(m.smooth for m in self.children)
+        self.mass = sum(m.mass for m in self.children)
+        self.sensitivity = sum(m.sensitivity for m in self.children)
+
     def forward(self, x, w):
-        output = []
-        for child in self.children:
-            w_child = w[:child.length]
-            output.append(child.forward(x, w_child))
-            w = w[child.length:]
-        return output
-
-    def initialize(self, device, dtype=torch.float32):
-        vector = Vector()
-        for child in self.children:
-            vector &= child.initialize(device, dtype=dtype)
-        return vector
-
-    def normalize(self, w, target_norm=1):
+        output_list = []
+        act_list = []
+        for m in self.children:
+            output, act = m.forward(x, w[:m.atoms])
+            output_list.append(output)
+            act_list += act
+            w = w[m.atoms:]
+        return output_list, act_list
+
+    def backward(self, w, acts, grad_output):
+        grad_w = []
+        grad_input = 0
+        for m, grad_output_m in zip(self.children, grad_output):
+            grad_w_m, grad_input_m = m.backward(w[:m.atoms], acts[:m.atoms+m.bonds], grad_output_m)
+            grad_w += grad_w_m
+            grad_input += grad_input_m
+            w = w[m.atoms:]
+            acts = acts[m.atoms+m.bonds:]
+        return grad_w, grad_input
+
+    def initialize(self, key):
+        w = []
+        for m in self.children:
+            key, subkey = jax.random.split(key)
+            w.append(m.initialize(subkey))
+        return w
+
+    def project(self, w):
+        projected_w = []
+        for m in self.children:
+            projected_w_m = m.project(w[:m.atoms])
+            projected_w.append(projected_w_m)
+            w = w[m.atoms:]
+        return projected_w
+
+    def dualize(self, grad_w, target_norm=1.0):
         if self.mass > 0:
-            for child in self.children:
-                w_child = Vector(w[:child.length])
-                child.normalize(w_child, target_norm=child.mass / self.mass * target_norm)
-                w = Vector(w[child.length:])
+            d_w = []
+            for m in self.children:
+                grad_w_m = grad_w[:m.atoms]
+                d_w_m = m.dualize(grad_w_m, target_norm = target_norm * m.mass / self.mass)
+                d_w += d_w_m
+                grad_w = grad_w[m.atoms:]
         else:
-            w *= 0
-
-    def regularize(self, w, strength):
-        if self.mass > 0:
-            for child in self.children:
-                w_child = Vector(w[:child.length])
-                child.regularize(w_child, strength=child.mass / self.mass * strength)
-                w = Vector(w[child.length:])
-
+            d_w = [0 * grad_weight for grad_weight in grad_w]
+        return d_w
 
-class Mul(Module):
-    def __init__(self, alpha):
+class Identity(Bond):
+    def __init__(self):
         super().__init__()
-        self.mass = 0
-        self.sensitivity = abs(alpha)
-        self.length = 0
-        self.initialize = lambda device, dtype : Vector()
-        self.normalize = lambda w, target_norm : None
-        self.regularize = lambda w, strength : None
-        self.alpha = alpha
-
-    def forward(self, x, _):
-        if isinstance(x, list):
-            return [self.forward(xi, _) for xi in x]
-        else:
-            return self.alpha * x
+        self.smooth = True
+        self.sensitivity = 1
 
+    def forward(self, x, w):
+        return x, [None]
+
+    def backward(self, w, acts, grad_output):
+        return [], grad_output
 
-class Add(Module):
+class Add(Bond):
     def __init__(self):
         super().__init__()
-        self.mass = 0
+        self.smooth = True
         self.sensitivity = 1
-        self.length = 0
-        self.initialize = lambda device, dtype : Vector()
-        self.normalize = lambda w, target_norm : None
-        self.regularize = lambda w, strength : None
-        self.forward = lambda x, w : sum(x)
+
+    def forward(self, x, w):
+        return sum(x), [None]
+
+    def backward(self, w, acts, grad_output):
+        return [], (grad_output, grad_output)
+
+class Mul(Bond):
+    def __init__(self, scalar):
+        super().__init__()
+        self.smooth = True
+        self.sensitivity = scalar
+
+    def forward(self, x, w):
+        return x * self.sensitivity, [None]
+
+    def backward(self, w, acts, grad_output):
+        return [], grad_output * self.sensitivity
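
Usage note (not part of the commit): a minimal sketch of how the reworked JAX-facing API above might be driven, using only the Bond subclasses defined in modula/abstract.py. The composed model, the PRNG key, and the input below are illustrative assumptions, not code from the repository.

    # Illustrative sketch of the new API (assumed usage, not part of the commit).
    import jax
    import jax.numpy as jnp

    from modula.abstract import Identity

    # Scalar multiplication routes through Mul via __rmul__, and @ composes modules.
    model = 2.0 * Identity() @ Identity()

    key = jax.random.PRNGKey(0)
    w = model.initialize(key)        # bonds hold no weights, so w == []

    x = jnp.ones(3)
    y, acts = model.forward(x, w)    # forward now returns (output, activation list)

    # CompositeModule.backward consumes the stored activations;
    # here the input gradient is just 2 * grad_output.
    grad_w, grad_x = model.backward(w, acts, jnp.ones(3))
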
