Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion elephant/spike_train_dissimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""

from __future__ import division, print_function, unicode_literals

import warnings
import numpy as np
import quantities as pq
from neo.core import SpikeTrain
Expand Down Expand Up @@ -363,6 +363,18 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True):
for i, j in np.ndindex(k_dist.shape):
vr_dist[i, j] = (
k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i])

# Clip small negative values
if np.any(vr_dist < 0):
warnings.warn(
"van_rossum_distance: very small negative values encountered; "
"setting them to zero. Potentially due to floating point error, "
"which can occur if spike times are represented as small floating "
"point values (e.g., in seconds). A possible way to prevent this "
"warning is to use a time unit with better numerical precision, "
"e.g., from seconds to milliseconds.", RuntimeWarning)
vr_dist = np.maximum(vr_dist, 0.0)

return np.sqrt(vr_dist)


Expand Down
24 changes: 22 additions & 2 deletions elephant/test/test_spike_train_dissimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import elephant.kernels as kernels
from elephant.spike_train_generation import StationaryPoissonProcess
import elephant.spike_train_dissimilarity as stds

import warnings
from elephant.datasets import download_datasets


Expand Down Expand Up @@ -73,7 +73,6 @@ def setUp(self):
self.tau7 = 0.01 * s
self.q7 = 1.0 / self.tau7
self.t = np.linspace(0, 200, 20000001) * ms

def test_wrong_input(self):
self.assertRaises(TypeError, stds.victor_purpura_distance,
[self.array1, self.array2], self.q3)
Expand Down Expand Up @@ -600,6 +599,27 @@ def test_van_rossum_distance(self):
[self.st21], self.tau3)[0, 0], 0)
self.assertEqual(len(stds.van_rossum_distance([], self.tau3)), 0)

def test_van_rossum_distance_regression_small_negative_values(self):
"""
Regression test for issue #679
Very small negative value in van_rossum_distance function.
Occurs due to floating point precision when
spike times are represented as small values
These values should be clipped to zero to avoid nans.
"""

st24 = SpikeTrain([0.1782, 0.2286, 0.2804, 0.4972, 0.5504],
units='s', t_stop=4.0)
tau8 = 0.1 * s
# Check small negative values edge case
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although there is a single unit test for the function, it is clearer to add a specific regression unit test for the issue. We have examples in other modules, e.g., test_instantaneous_rate_regression_288 in elephant.statistics. It is good to point to the issue number and have a short comment.

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = stds.van_rossum_distance([st24, st24], tau8)
self.assertTrue(any("very small negative values encountered"
in str(warn.message) for warn in w))
self.assertEqual(result[0, 1], 0.0)
self.assertFalse(np.any(np.isnan(result)))


if __name__ == '__main__':
unittest.main()
Loading