
Commit bbb0014

Revert "rename"
This reverts commit 8d081d4.
1 parent 8d081d4 commit bbb0014

File tree

7 files changed: +174 −0 lines changed
6 files renamed without changes.

modula/vector.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
import torch


class Vector:
    """For doing algebra on lists of tensors.

    An instance of Vector stores a list of tensors. Vectors can be
    added, subtracted, scalar-multiplied, elementwise-multiplied, etc.
    We also support in-place operations for efficiency.

    Vectors are intended to store the weights of a neural net,
    allowing weight updates to be implemented using simple algebra.
    """

    def __init__(self, tensor_or_tensor_list=None):
        """Stores a list of tensors."""
        if tensor_or_tensor_list is None:
            # Avoid a mutable default argument: a shared default list
            # would be aliased across every Vector() constructed bare.
            tensor_or_tensor_list = []
        if isinstance(tensor_or_tensor_list, torch.Tensor):
            self.tensor_list = [tensor_or_tensor_list]
        elif isinstance(tensor_or_tensor_list, list):
            self.tensor_list = tensor_or_tensor_list
        elif isinstance(tensor_or_tensor_list, tuple):
            # Convert to a list so that concatenation in __and__ works.
            self.tensor_list = list(tensor_or_tensor_list)
        else:
            raise NotImplementedError

    def __getitem__(self, item):
        """Allows Vectors to be indexed and looped over."""
        return self.tensor_list[item]

    def __len__(self):
        return len(self.tensor_list)

    def grad(self):
        """Returns the gradient list of this Vector."""
        return Vector([tensor.grad for tensor in self])

    def zero_grad(self):
        """Delete the gradients of this Vector."""
        for tensor in self:
            tensor.grad = None

    def zero_nans(self):
        """Set any nans or infs to zero, in-place."""
        for tensor in self:
            tensor.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)

    @torch.no_grad()
    def all_reduce(self):
        """Sums this vector over all workers."""
        for tensor in self:
            torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)

    @torch.no_grad()
    def broadcast(self):
        """Broadcasts this vector from worker zero to all other workers."""
        for tensor in self:
            torch.distributed.broadcast(tensor, src=0)

    def __str__(self):
        """Lets us print the Vector."""
        return str(self.tensor_list)

    def __and__(self, other):
        """Concatenate two Vectors."""
        return Vector(self.tensor_list + other.tensor_list)

    def __iadd__(self, other):
        """In-place add."""
        if len(self) == 0: return self
        if isinstance(other, Vector): other = other.tensor_list
        torch._foreach_add_(self.tensor_list, other)
        return self

    def __add__(self, other):
        """Add."""
        if len(self) == 0: return Vector()
        if isinstance(other, Vector): other = other.tensor_list
        new_list = torch._foreach_add(self.tensor_list, other)
        return Vector(new_list)

    def __mul__(self, other):
        """Multiply."""
        if len(self) == 0: return Vector()
        if isinstance(other, Vector): other = other.tensor_list
        new_list = torch._foreach_mul(self.tensor_list, other)
        return Vector(new_list)

    def __rmul__(self, other):
        """Multiply from the left."""
        return self * other

    def __imul__(self, other):
        """In-place multiply."""
        if len(self) == 0: return self
        if isinstance(other, Vector): other = other.tensor_list
        torch._foreach_mul_(self.tensor_list, other)
        return self

    def __isub__(self, other):
        """In-place subtract."""
        if len(self) == 0: return self
        if isinstance(other, Vector): other = other.tensor_list
        torch._foreach_sub_(self.tensor_list, other)
        return self

    def __sub__(self, other):
        """Subtract."""
        if len(self) == 0: return Vector()
        if isinstance(other, Vector): other = other.tensor_list
        new_list = torch._foreach_sub(self.tensor_list, other)
        return Vector(new_list)

    def __itruediv__(self, other):
        """In-place division."""
        if len(self) == 0: return self
        if isinstance(other, Vector): other = other.tensor_list
        torch._foreach_div_(self.tensor_list, other)
        return self

    def __truediv__(self, other):
        """Division."""
        if len(self) == 0: return Vector()
        if isinstance(other, Vector): other = other.tensor_list
        new_list = torch._foreach_div(self.tensor_list, other)
        return Vector(new_list)

    def __ipow__(self, other):
        """In-place power."""
        if len(self) == 0: return self
        if isinstance(other, Vector): other = other.tensor_list
        torch._foreach_pow_(self.tensor_list, other)
        return self

    def __pow__(self, other):
        """Power."""
        if len(self) == 0: return Vector()
        if isinstance(other, Vector): other = other.tensor_list
        new_list = torch._foreach_pow(self.tensor_list, other)
        return Vector(new_list)


if __name__ == "__main__":

    a = Vector([torch.tensor(2.0), torch.tensor(1.0)])

    # In-place operations against scalars.
    a *= 2; print(a)
    a += 1; print(a)
    a -= 1; print(a)
    a /= 2; print(a)
    a **= 2; print(a)

    a = Vector([torch.tensor(2.0), torch.tensor(1.0)])

    # In-place operations against another Vector.
    a **= a; print(a)
    a *= a; print(a)
    a /= a; print(a)
    a += a; print(a)
    a -= a; print(a)

    a = Vector([torch.tensor(2.0), torch.tensor(1.0)])

    # Out-of-place operations against scalars.
    a = a * 2; print(a)
    a = a + 1; print(a)
    a = a - 1; print(a)
    a = a / 2; print(a)
    a = a ** 2; print(a)

    a = Vector([torch.tensor(2.0), torch.tensor(1.0)])

    # Out-of-place operations against another Vector.
    a = a * a; print(a)
    a = a + a; print(a)
    a = a / a; print(a)
    a = a ** a; print(a)
    a = a - a; print(a)
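
For illustration, here is a minimal sketch of the weight-update algebra that the class docstring describes: collect parameters into a Vector, then write an SGD step as plain arithmetic. The parameters, loss, and learning rate below are hypothetical placeholders, not part of this commit:

    import torch
    from modula.vector import Vector

    # Hypothetical setup: any list of leaf parameter tensors will do.
    w = Vector([torch.randn(3, 3, requires_grad=True) for _ in range(2)])

    for step in range(10):
        loss = sum((t ** 2).sum() for t in w)  # placeholder loss
        loss.backward()
        with torch.no_grad():
            w -= 0.1 * w.grad()  # the SGD step, written as Vector algebra
        w.zero_grad()

Here 0.1 * w.grad() goes through __rmul__ and returns a fresh Vector of scaled gradients, and w -= ... applies torch._foreach_sub_ across all parameters in one fused call.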

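The distributed helpers follow the usual torch.distributed pattern: broadcast() synchronizes weights from rank zero, and all_reduce() sums tensors across workers. A hedged usage sketch, assuming a process group has already been initialized (for example via torchrun) and using placeholder tensors and the gloo backend:

    import torch
    import torch.distributed as dist
    from modula.vector import Vector

    dist.init_process_group("gloo")  # assumed backend; adjust to your setup

    w = Vector([torch.randn(4), torch.randn(4)])
    w.broadcast()  # every rank now holds rank zero's tensors

    g = Vector([torch.ones(4), torch.ones(4)])  # stand-in for local gradients
    g.all_reduce()                 # sums over all workers, in place
    g /= dist.get_world_size()     # average, if desired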