Skip to content

Commit 12d8e03

Browse files
committed
Sparsity for s monitor
1 parent 8b1cb4e commit 12d8e03

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

bindsnet/evaluation/evaluation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ def assign_labels(
4444
indices = torch.nonzero(labels == i).view(-1)
4545

4646
# Compute average firing rates for this label.
47+
selected_spikes = torch.index_select(spikes, dim=0, index=torch.tensor(indices))
4748
rates[:, i] = alpha * rates[:, i] + (
48-
torch.sum(spikes[indices], 0) / n_labeled
49+
torch.sum(selected_spikes, 0) / n_labeled
4950
)
5051

5152
# Compute proportions of spike activity per class.
@@ -111,6 +112,8 @@ def all_activity(
111112

112113
# Sum over time dimension (spike ordering doesn't matter).
113114
spikes = spikes.sum(1)
115+
if spikes.is_sparse:
116+
spikes = spikes.to_dense()
114117

115118
rates = torch.zeros((n_samples, n_labels), device=spikes.device)
116119
for i in range(n_labels):
@@ -152,6 +155,8 @@ def proportion_weighting(
152155

153156
# Sum over time dimension (spike ordering doesn't matter).
154157
spikes = spikes.sum(1)
158+
if spikes.is_sparse:
159+
spikes = spikes.to_dense()
155160

156161
rates = torch.zeros((n_samples, n_labels), device=spikes.device)
157162
for i in range(n_labels):

bindsnet/network/monitors.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
time: Optional[int] = None,
4646
batch_size: int = 1,
4747
device: str = "cpu",
48+
sparse: Optional[bool] = False
4849
):
4950
# language=rst
5051
"""
@@ -62,6 +63,7 @@ def __init__(
6263
self.time = time
6364
self.batch_size = batch_size
6465
self.device = device
66+
self.sparse = sparse
6567

6668
# if time is not specified the monitor variable accumulate the logs
6769
if self.time is None:
@@ -98,11 +100,12 @@ def record(self) -> None:
98100
for v in self.state_vars:
99101
data = getattr(self.obj, v).unsqueeze(0)
100102
# self.recording[v].append(data.detach().clone().to(self.device))
101-
self.recording[v].append(
102-
torch.empty_like(data, device=self.device, requires_grad=False).copy_(
103-
data, non_blocking=True
104-
)
103+
record = torch.empty_like(data, device=self.device, requires_grad=False).copy_(
104+
data, non_blocking=True
105105
)
106+
if self.sparse:
107+
record = record.to_sparse()
108+
self.recording[v].append(record)
106109
# remove the oldest element (first in the list)
107110
if self.time is not None:
108111
self.recording[v].pop(0)

examples/mnist/batch_eth_mnist.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
spikes = {}
148148
for layer in set(network.layers):
149149
spikes[layer] = Monitor(
150-
network.layers[layer], state_vars=["s"], time=int(time / dt), device=device
150+
network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True
151151
)
152152
network.add_monitor(spikes[layer], name="%s_spikes" % layer)
153153

@@ -165,7 +165,8 @@
165165
perf_ax = None
166166
voltage_axes, voltage_ims = None, None
167167

168-
spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)
168+
spike_record = [torch.zeros((batch_size, int(time / dt), n_neurons), device=device).to_sparse() for _ in range(update_interval // batch_size)]
169+
spike_record_idx = 0
169170

170171
# Train the network.
171172
print("\nBegin training...")
@@ -197,12 +198,13 @@
197198
# Convert the array of labels into a tensor
198199
label_tensor = torch.tensor(labels, device=device)
199200

201+
spike_record_tensor = torch.cat(spike_record, dim=0)
200202
# Get network predictions.
201203
all_activity_pred = all_activity(
202-
spikes=spike_record, assignments=assignments, n_labels=n_classes
204+
spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes
203205
)
204206
proportion_pred = proportion_weighting(
205-
spikes=spike_record,
207+
spikes=spike_record_tensor,
206208
assignments=assignments,
207209
proportions=proportions,
208210
n_labels=n_classes,
@@ -240,7 +242,7 @@
240242

241243
# Assign labels to excitatory layer neurons.
242244
assignments, proportions, rates = assign_labels(
243-
spikes=spike_record,
245+
spikes=spike_record_tensor,
244246
labels=label_tensor,
245247
n_labels=n_classes,
246248
rates=rates,
@@ -261,11 +263,10 @@
261263

262264
# Add to spikes recording.
263265
s = spikes["Ae"].get("s").permute((1, 0, 2))
264-
spike_record[
265-
(step * batch_size)
266-
% update_interval : (step * batch_size % update_interval)
267-
+ s.size(0)
268-
] = s
266+
spike_record[spike_record_idx] = s
267+
spike_record_idx += 1
268+
if spike_record_idx == len(spike_record):
269+
spike_record_idx = 0
269270

270271
# Get voltage recording.
271272
exc_voltages = exc_voltage_monitor.get("v")

0 commit comments

Comments
 (0)