diff --git a/elephant/spike_train_dissimilarity.py b/elephant/spike_train_dissimilarity.py index 3234f8916..064a36128 100644 --- a/elephant/spike_train_dissimilarity.py +++ b/elephant/spike_train_dissimilarity.py @@ -21,7 +21,7 @@ """ from __future__ import division, print_function, unicode_literals - +import warnings import numpy as np import quantities as pq from neo.core import SpikeTrain @@ -363,6 +363,18 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True): for i, j in np.ndindex(k_dist.shape): vr_dist[i, j] = ( k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i]) + + # Clip small negative values + if np.any(vr_dist < 0): + warnings.warn( + "van_rossum_distance: very small negative values encountered; " + "setting them to zero. Potentially due to floating point error, " + "which can occur if spike times are represented as small floating " + "point values (e.g., in seconds). A possible way to prevent this " + "warning is to use a time unit with better numerical precision, " + "e.g., from seconds to milliseconds.", RuntimeWarning) + vr_dist = np.maximum(vr_dist, 0.0) + return np.sqrt(vr_dist) diff --git a/elephant/test/test_spike_train_dissimilarity.py b/elephant/test/test_spike_train_dissimilarity.py index 4619d4bba..c856ad878 100644 --- a/elephant/test/test_spike_train_dissimilarity.py +++ b/elephant/test/test_spike_train_dissimilarity.py @@ -14,7 +14,7 @@ import elephant.kernels as kernels from elephant.spike_train_generation import StationaryPoissonProcess import elephant.spike_train_dissimilarity as stds - +import warnings from elephant.datasets import download_datasets @@ -73,7 +73,6 @@ def setUp(self): self.tau7 = 0.01 * s self.q7 = 1.0 / self.tau7 self.t = np.linspace(0, 200, 20000001) * ms - def test_wrong_input(self): self.assertRaises(TypeError, stds.victor_purpura_distance, [self.array1, self.array2], self.q3) @@ -600,6 +599,27 @@ def test_van_rossum_distance(self): [self.st21], self.tau3)[0, 0], 0) self.assertEqual(len(stds.van_rossum_distance([], self.tau3)), 0) + def test_van_rossum_distance_regression_small_negative_values(self): + """ + Regression test for issue #679 + Very small negative value in van_rossum_distance function. + Occurs due to floating point precision when + spike times are represented as small values + These values should be clipped to zero to avoid nans. + """ + + st24 = SpikeTrain([0.1782, 0.2286, 0.2804, 0.4972, 0.5504], + units='s', t_stop=4.0) + tau8 = 0.1 * s + # Check small negative values edge case + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = stds.van_rossum_distance([st24, st24], tau8) + self.assertTrue(any("very small negative values encountered" + in str(warn.message) for warn in w)) + self.assertEqual(result[0, 1], 0.0) + self.assertFalse(np.any(np.isnan(result))) + if __name__ == '__main__': unittest.main()