diff --git a/numtraits.py b/numtraits.py index 299ed3c..b01d019 100644 --- a/numtraits.py +++ b/numtraits.py @@ -25,7 +25,7 @@ from __future__ import print_function -from traitlets import TraitType, TraitError +from traitlets import TraitType, TraitError, Undefined import numpy as np @@ -38,16 +38,13 @@ class NumericalTrait(TraitType): info_text = 'a numerical trait, either a scalar or a vector' def __init__(self, ndim=None, shape=None, domain=None, - default=None, convertible_to=None): - super(NumericalTrait, self).__init__() + default_value=Undefined, convertible_to=None, allow_none=False): + super(NumericalTrait, self).__init__(default_value=default_value,allow_none=allow_none) - # Just store all the construction arguments. + # Store the construction arguments. self.ndim = ndim self.shape = shape self.domain = domain - # TODO: traitlets supports a `default` argument in __init__(), we should - # probably link them together once we start using this. - self.default = default self.target_unit = convertible_to if self.target_unit is not None: @@ -65,7 +62,6 @@ def _check_args(self): raise TraitError("shape={0} and ndim={1} are inconsistent".format(self.shape, self.ndim)) def validate(self, obj, value): - # We proceed by checking whether Numpy tells us the value is a # scalar. If Numpy isscalar returns False, it could still be scalar # but be a Quantity with units, so we then extract the numerical diff --git a/test_numtraits.py b/test_numtraits.py index 5cd9ff7..e13916b 100644 --- a/test_numtraits.py +++ b/test_numtraits.py @@ -13,6 +13,9 @@ class ScalarProperties(HasTraits): d = NumericalTrait(ndim=0, domain='negative') e = NumericalTrait(ndim=0, domain='strictly-negative') f = NumericalTrait(ndim=0, domain=(3, 4)) + g = NumericalTrait(ndim=0, allow_none=True, default_value=None) + h = NumericalTrait(ndim=0, allow_none=True, default_value=1.23) + i = NumericalTrait(ndim=0, default_value=4.56) class TestScalar(object): @@ -79,6 +82,20 @@ def test_range(self): self.sp.f = 7 assert exc.value.args[0] == "f should be in the range [3:4]" + def test_nullable_default(self): + assert self.sp.g is None + assert self.sp.h == 1.23 + assert self.sp.i == 4.56 + self.sp.g = 1.2 + assert self.sp.g == 1.2 + self.sp.h = None + assert self.sp.h is None + self.sp.i = 1.23 + assert self.sp.i == 1.23 + with pytest.raises(TraitError) as exc: + self.sp.i = None + assert exc.value.args[0] == "i should be a scalar value" + class ArrayProperties(HasTraits):