diff --git a/numtraits.py b/numtraits.py index 45d059f..b6d416d 100644 --- a/numtraits.py +++ b/numtraits.py @@ -25,7 +25,9 @@ from __future__ import print_function -from traitlets import TraitType, TraitError +import warnings + +from traitlets import TraitType, TraitError, Undefined import numpy as np @@ -38,16 +40,18 @@ class NumericalTrait(TraitType): info_text = 'a numerical trait, either a scalar or a vector' def __init__(self, ndim=None, shape=None, domain=None, - default=None, convertible_to=None): - super(NumericalTrait, self).__init__() + default_value=Undefined, convertible_to=None, default=Undefined): + if default is not Undefined: + if default_value is not Undefined: + raise TypeError('Cannot set default and default_value simultaneously') + warnings.warn(DeprecationWarning('`default` has been renamed to `default_value`')) + default_value = default + super(NumericalTrait, self).__init__(default_value=default_value) # Just store all the construction arguments. self.ndim = ndim self.shape = shape self.domain = domain - # TODO: traitlets supports a `default` argument in __init__(), we should - # probably link them together once we start using this. - self.default = default self.target_unit = convertible_to if self.target_unit is not None: @@ -97,7 +101,7 @@ def validate(self, obj, value): if self.ndim is not None: if self.ndim == 0: - if not is_scalar: + if not is_scalar and num_value.ndim: raise TraitError("{0} should be a scalar value".format(self.name)) if self.ndim > 0: diff --git a/test_numtraits.py b/test_numtraits.py index 5cd9ff7..c4af70f 100644 --- a/test_numtraits.py +++ b/test_numtraits.py @@ -13,6 +13,7 @@ class ScalarProperties(HasTraits): d = NumericalTrait(ndim=0, domain='negative') e = NumericalTrait(ndim=0, domain='strictly-negative') f = NumericalTrait(ndim=0, domain=(3, 4)) + g = NumericalTrait(ndim=0, default_value=2) class TestScalar(object): @@ -79,6 +80,14 @@ def test_range(self): self.sp.f = 7 assert exc.value.args[0] == "f should be in the range [3:4]" + def test_scalar_quantities(self): + """ Tests for issue #14. + """ + quantities = pytest.importorskip("quantities") + self.sp.a = 1*quantities.m + + def test_default_value(self): + assert self.sp.g == 2 class ArrayProperties(HasTraits): @@ -90,6 +99,7 @@ class ArrayProperties(HasTraits): f = NumericalTrait(domain=(3, 4), ndim=1) g = NumericalTrait(shape=(3, 4)) + class TestArray(object): def setup_method(self, method):