Skip to content

Commit ddfd58f

Browse files
authored
Merge pull request #773 from apdavison/enhance-translation
ok, I thought this was a standalone PR, but some Brian2 tests will fail until the matching changes are made. Merging anyway, essentially for documentation purposes.
2 parents 97ba231 + cb5a48e commit ddfd58f

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

pyNN/standardmodels/__init__.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@
3636
# ==============================================================================
3737

3838

39+
def build_scaling_functions(pynn_name, sim_name, scale_factor):
40+
def f(**p):
41+
return p[pynn_name] * scale_factor
42+
def g(**p):
43+
return p[sim_name] / scale_factor
44+
return f, g
45+
46+
3947
def build_translations(*translation_list):
4048
"""
4149
Build a translation dictionary from a list of translations/transformations.
@@ -49,16 +57,19 @@ def build_translations(*translation_list):
4957
if len(item) == 2: # no transformation
5058
f = pynn_name
5159
g = sim_name
60+
type_ = "simple"
5261
elif len(item) == 3: # simple multiplicative factor
5362
scale_factor = item[2]
54-
f = "float(%g)*%s" % (scale_factor, pynn_name)
55-
g = "%s/float(%g)" % (sim_name, scale_factor)
63+
f, g = build_scaling_functions(pynn_name, sim_name, scale_factor)
64+
type_ = "scaled"
5665
elif len(item) == 4: # more complex transformation
5766
f = item[2]
5867
g = item[3]
68+
type_ = "computed"
5969
translations[pynn_name] = {'translated_name': sim_name,
6070
'forward_transform': f,
61-
'reverse_transform': g}
71+
'reverse_transform': g,
72+
'type': type_}
6273
return translations
6374

6475

@@ -133,21 +144,19 @@ def simple_parameters(self):
133144
"""Return a list of parameters for which there is a one-to-one
134145
correspondance between standard and native parameter values."""
135146
return [name for name in self.translations
136-
if self.translations[name]['forward_transform'] == name]
147+
if self.translations[name]['type'] == "simple"]
137148

138149
def scaled_parameters(self):
139150
"""Return a list of parameters for which there is a unit change between
140151
standard and native parameter values."""
141-
def scaling(trans):
142-
return (not callable(trans)) and ("float" in trans)
143152
return [name for name in self.translations
144-
if scaling(self.translations[name]['forward_transform'])]
153+
if self.translations[name]['type'] == "scaled"]
145154

146155
def computed_parameters(self):
147156
"""Return a list of parameters whose values must be computed from
148157
more than one other parameter."""
149158
return [name for name in self.translations
150-
if name not in self.simple_parameters() + self.scaled_parameters()]
159+
if self.translations[name]['type'] == "computed"]
151160

152161
def computed_parameters_include(self, parameter_names):
153162
return any(name in self.computed_parameters() for name in parameter_names)
@@ -203,12 +212,13 @@ def __getattr__(self, name):
203212
"e.g. source.amplitude = 0.5, or use 'set_parameters()' " \
204213
"e.g. source.set_parameters(amplitude=0.5)"
205214
raise AttributeError(err_msg)
215+
206216
try:
207-
val = self.__getattribute__(name)
208-
except AttributeError:
217+
val = self.get_parameters()[name]
218+
except KeyError:
209219
try:
210-
val = self.get_parameters()[name]
211-
except KeyError:
220+
val = self.__getattribute__(name)
221+
except AttributeError:
212222
raise errors.NonExistentParameterError(name,
213223
self.__class__.__name__,
214224
self.get_parameter_names())

test/unittests/test_standardmodels.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@ def test_build_translations():
1818
('c', 'C', 'c + a', 'C - A')
1919
)
2020
assert set(t.keys()) == set(['a', 'b', 'c'])
21-
assert set(t['a'].keys()) == set(['translated_name', 'forward_transform', 'reverse_transform'])
21+
assert set(t['a'].keys()) == set(['translated_name', 'forward_transform', 'reverse_transform', 'type'])
2222
assert t['a']['translated_name'] == 'A'
2323
assert t['a']['forward_transform'] == 'a'
2424
assert t['a']['reverse_transform'] == 'A'
2525
assert t['b']['translated_name'] == 'B'
26-
assert t['b']['forward_transform'] == 'float(1000)*b'
27-
assert t['b']['reverse_transform'] == 'B/float(1000)'
26+
assert callable(t['b']['forward_transform'])
27+
assert t['b']['forward_transform'](b=7) == 7000
28+
assert callable(t['b']['reverse_transform'])
29+
assert t['b']['reverse_transform'](B=7000) == 7
2830
assert t['c']['translated_name'] == 'C'
2931
assert t['c']['forward_transform'] == 'c + a'
3032
assert t['c']['reverse_transform'] == 'C - A'

0 commit comments

Comments
 (0)