From f70c9cd4d909b7de0add151c6db30cc184867cfc Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 08:44:03 +0100 Subject: [PATCH 01/29] Added schemas Folder with __init__.py --- elephant/schemas/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 elephant/schemas/__init__.py diff --git a/elephant/schemas/__init__.py b/elephant/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb From e007679c1bad33b1870c498c454945b61f7e17e6 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 08:47:22 +0100 Subject: [PATCH 02/29] Added field_validator.py to group repeated validation --- elephant/schemas/field_validator.py | 244 ++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 elephant/schemas/field_validator.py diff --git a/elephant/schemas/field_validator.py b/elephant/schemas/field_validator.py new file mode 100644 index 000000000..dbd01726e --- /dev/null +++ b/elephant/schemas/field_validator.py @@ -0,0 +1,244 @@ +import numpy as np +import quantities as pq +import neo +import elephant +from enum import Enum +from typing import Any +import warnings + +def get_length(obj) -> int: + """ + Return the length (number of elements) of various supported datatypes: + - list + - numpy.ndarray + - pq.Quantity + - neo.SpikeTrain + + Returns + ------- + int + The number of elements or spikes in the object. + + Raises + ------ + TypeError + If the object type is not supported. + """ + if obj is None: + raise ValueError("Cannot get length of None") + + if isinstance(obj, elephant.trials.Trials): + return len(obj.trials) + elif isinstance(obj, elephant.conversion.BinnedSpikeTrain): + return obj.n_bins + elif isinstance(obj, neo.SpikeTrain): + return len(obj) + elif isinstance(obj, pq.Quantity): + return obj.size + elif isinstance(obj, np.ndarray): + return obj.size + elif isinstance(obj, (list,tuple)): + return len(obj) + + + + else: + raise TypeError( + f"Unsupported type for length computation: {type(obj).__name__}" + ) + +def is_sorted(obj) -> bool: + if obj is None: + raise ValueError("Cannot check sortedness of None") + + if isinstance(obj, (list, np.ndarray, pq.Quantity)): + arr = np.asarray(obj) + return np.all(arr[:-1] <= arr[1:]) + elif isinstance(obj, neo.SpikeTrain): + arr = obj.magnitude # Get the underlying numpy array of spike times + return np.all(arr[:-1] <= arr[1:]) + return False + +def is_matrix(obj) -> bool: + if obj is None: + raise ValueError("Cannot check matrix of None") + if isinstance(obj, (list, np.ndarray, pq.Quantity)): + arr = np.asarray(obj) + return arr.ndim >= 2 + elif isinstance(obj, neo.SpikeTrain): + arr = obj.magnitude # Get the underlying numpy array of spike times + return arr.ndim >= 2 + return False + +def validate_covariance_matrix_rank_deficient(obj, info): + """ + Check if the covariance matrix of the given object is rank deficient. + Should work for elephant.trials.Trials, list of neo.core.spiketrainlist.SpikeTrainList or list of list of neo.core.SpikeTrain. 
+ """ + return obj + +def validate_type( + value, + info, + allowed_types: tuple, + allow_none: bool, +): + """Generic type validation helper.""" + if value is None: + if allow_none: + return value + raise ValueError(f"{info.field_name} cannot be None") + + if not isinstance(value, allowed_types): + raise TypeError(f"{info.field_name} must be one of {allowed_types}, not {type(value).__name__}") + return value + +def validate_length( + value, + info: str, + min_length: int, + warning: bool +): + if min_length>0: + if get_length(value) < min_length: + if warning: + warnings.warn(f"{info.field_name} has less than {min_length} elements", UserWarning) + else: + raise ValueError(f"{info.field_name} must contain at least {min_length} elements") + return value + +def validate_type_length(value, info, allowed_types: tuple, allow_none: bool, min_length: int, warning: bool = False): + validate_type(value, info, allowed_types, allow_none) + if value is not None: + validate_length(value, info, min_length, warning) + return value + +def validate_array_content(value, info, allowed_types: tuple, allow_none: bool, min_length: int, allowed_content_types: tuple, min_length_content: int = 0): + validate_type_length(value, info, allowed_types, allow_none, min_length) + for i, item in enumerate(value): + if not isinstance(item, allowed_content_types): + raise TypeError(f"Element {i} in {info.field_name} must be {allowed_content_types}, not {type(item).__name__}") + if min_length_content > 0 and get_length(item) >= min_length_content: + hasContentLength = True + if(min_length_content > 0 and not hasContentLength): + raise ValueError(f"{info.field_name} must contain at least one element with at least {min_length_content} elements") + + return value + +# ---- Specialized validation helpers ---- + +def validate_spiketrain(value, info, allowed_types=(list, neo.SpikeTrain, pq.Quantity, np.ndarray), allow_none = False, min_length = 1, check_sorted = False): + validate_type_length(value, info, allowed_types, allow_none, min_length) + if(check_sorted): + if value is not None and not is_sorted(value): + warnings.warn(f"{info.field_name} is not sorted", UserWarning) + if(isinstance(value, neo.SpikeTrain)): + if value.t_start is not None and value.t_stop is not None: + if value.t_start > value.t_stop: + raise ValueError(f"{info.field_name} has t_start > t_stop") + return value + +def validate_spiketrains(value, info, allowed_types = (list,), allow_none = False, min_length = 1, allowed_content_types = (list, neo.SpikeTrain, pq.Quantity, np.ndarray), min_length_content = 0): + validate_array_content(value, info, allowed_types, allow_none, min_length, allowed_content_types, min_length_content) + return value + +def validate_spiketrains_matrix(value, info, allowed_types = (elephant.trials.Trials, list[neo.core.spiketrainlist.SpikeTrainList], list[list[neo.core.SpikeTrain]]), allow_none = False, min_length = 1, check_rank_deficient = False): + if isinstance(value, list): + validate_spiketrains(value, info, allowed_content_types=(neo.core.spiketrainlist,list[neo.core.SpikeTrain],)) + else: + validate_type(value, info, (elephant.trials.Trials,), allow_none=False) + if check_rank_deficient: + return validate_covariance_matrix_rank_deficient(value, info) + return value + +def validate_time(value, info, allowed_types=(float, pq.Quantity) ,allow_none=True): + if(isinstance(value, np.ndarray) and value.size==1): + value = value.item() + + validate_type(value, info, allowed_types, allow_none) + return value + +def 
validate_quantity(value, info, allow_none=False): + validate_type(value, info, (pq.Quantity,), allow_none) + return value + +def validate_time_intervals(value, info, allowed_types = (list, pq.Quantity, np.ndarray), allow_none = False, min_length=0, check_matrix = False): + validate_type_length(value, info, allowed_types, allow_none, min_length) + if check_matrix: + if value is not None and is_matrix(value): + raise ValueError(f"{info.field_name} is not allowed to be a matrix") + return value + +def validate_array(value, info, allowed_types=(list, np.ndarray) , allow_none=False, min_length=1, allowed_content_types = None, min_length_content = 0): + if allowed_content_types is None: + validate_type_length(value, info, allowed_types, allow_none, min_length) + else: + validate_array_content(value, info, allowed_types, allow_none, min_length, allowed_content_types, min_length_content) + return value + +def validate_binned_spiketrain(value, info, allowed_types=(elephant.conversion.BinnedSpikeTrain,), allow_none=False, min_length=1): + validate_type_length(value, info, allowed_types, allow_none, min_length, warning=True) + if value is not None and isinstance(value, elephant.conversion.BinnedSpikeTrain): + spmat = value.sparse_matrix + + # Check for empty spike trains + n_spikes_per_row = spmat.sum(axis=1) + if n_spikes_per_row.min() == 0: + warnings.warn( + f'Detected empty spike trains (rows) in the {info.field_name}.', UserWarning) + return value + +def validate_dict_enum_types(value : dict[Enum, Any], info, typeDictionary: dict[Enum, type]): + for key, val in value.items(): + if not isinstance(val, typeDictionary[key]): + raise TypeError(f"Value for key {key} in {info.field_name} must be of type {typeDictionary[key].__name__}, not {type(val).__name__}") + return value + +def validate_key_in_tuple(value : str, info, t: tuple): + if value not in t: + raise ValueError(f"{info}:{value} is not in the options {t}") + return value + + +# ---- Model validation helpers ---- + +def model_validate_spiketrains_same_t_start_stop(spiketrain, t_start, t_stop, name: str = "spiketrains", warning: bool = False): + if(t_start is None or t_stop is None): + first = True + for i, item in enumerate(spiketrain): + if first: + t_start = item.t_start + t_stop = item.t_stop + first = False + else: + if t_start is None and item.t_start != t_start: + if warning: + warnings.warn(f"{name} has different t_start values among its elements", UserWarning) + else: + raise ValueError(f"{name} has different t_start values among its elements") + if t_stop is None and item.t_stop != t_stop: + if warning: + warnings.warn(f"{name} has different t_stop values among its elements", UserWarning) + else: + raise ValueError(f"{name} has different t_stop values among its elements") + else: + if t_start>t_stop: + raise ValueError(f"{name} has t_start > t_stop") + +def model_validate_spiketrains_sam_t_start_stop(spiketrain_i, spiketrain_j): + if spiketrain_i.t_start != spiketrain_j.t_start: + raise ValueError("spiketrain_i and spiketrain_j need to have the same t_start") + if spiketrain_i.t_stop != spiketrain_j.t_stop: + raise ValueError("spiketrain_i and spiketrain_j need to have the same t_stop") + +def model_validate_time_intervals_with_nan(time_intervals , with_nan, name: str = "time_intervals"): + if get_length(time_intervals)<2: + if(with_nan): + warnings.warn(f"{name} has less than two entries so a np.Nan will be generated", UserWarning) + else: + raise ValueError(f"{name} has less than two entries") + +def 
model_validate_binned_spiketrain_fast(binned_spiketrain, fast, name: str = "binned_spiketrain"): + if(fast and np.max(binned_spiketrain.shape) > np.iinfo(np.int32).max): + raise MemoryError(f"{name} is too large for fast=True option") + \ No newline at end of file From ccac5cc51eecbd584f23270f8ede61846b62991a Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 08:48:37 +0100 Subject: [PATCH 03/29] Added field_serializer.py to group repeated serialization --- elephant/schemas/field_serializer.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 elephant/schemas/field_serializer.py diff --git a/elephant/schemas/field_serializer.py b/elephant/schemas/field_serializer.py new file mode 100644 index 000000000..d7d2587bd --- /dev/null +++ b/elephant/schemas/field_serializer.py @@ -0,0 +1,9 @@ +import quantities as pq + +def serialize_quantity(value: pq.Quantity) -> dict: + if value is None: + return None + return { + "value": value.magnitude, + "unit": value.dimensionality + } \ No newline at end of file From 11c2a35bf2a4ea0518ecce96fc4a3f2bd9df3666 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 08:53:06 +0100 Subject: [PATCH 04/29] added function_validator.py to validate elephant functions --- elephant/schemas/function_validator.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 elephant/schemas/function_validator.py diff --git a/elephant/schemas/function_validator.py b/elephant/schemas/function_validator.py new file mode 100644 index 000000000..8d9f60686 --- /dev/null +++ b/elephant/schemas/function_validator.py @@ -0,0 +1,27 @@ +from functools import wraps +from inspect import signature +from pydantic import BaseModel + +def validate_with(model_class: type[BaseModel]): + """ + A decorator that validates the inputs of a function using a Pydantic model. + Works for both positional and keyword arguments. + """ + def decorator(func): + sig = signature(func) + + @wraps(func) + def wrapper(*args, **kwargs): + # Bind args & kwargs to function parameters + bound = sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + data = bound.arguments + + # Validate using Pydantic + validated = model_class(**data) + + # Call function with validated data unpacked + return func(**validated.model_dump()) + + return wrapper + return decorator \ No newline at end of file From 46ff38b3cbcbfc7a615955dd64586861e97202c6 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 09:03:54 +0100 Subject: [PATCH 05/29] Added Pydantic Models for statistics --- elephant/schemas/schema_statistics.py | 343 ++++++++++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 elephant/schemas/schema_statistics.py diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py new file mode 100644 index 000000000..9314e2521 --- /dev/null +++ b/elephant/schemas/schema_statistics.py @@ -0,0 +1,343 @@ +import quantities as pq +import numpy as np +from typing import ( + Any, + Union, + Self, + Optional +) +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, + field_serializer +) +import neo +from enum import Enum +import elephant + +from elephant.kernels import Kernel +import elephant.schemas.field_validator as fv +import elephant.schemas.field_serializer as fs + +import warnings + +class PydanticMeanFiringRate(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.mean_firing_rate function + with additional type checking and json_schema by PyDantic. 
+ """ + spiketrain: Any = Field(None, description="SpikeTrain Object") + t_start: Optional[Any] = Field(None, description="Start time") + t_stop: Optional[Any] = Field(None, description="Stop time") + axis: Optional[int] = Field(None, description="Axis of calculation") + + @field_validator("spiketrain") + @classmethod + def validate_spiketrain(cls, v, info): + return fv.validate_spiketrain(v, info, allow_none=True) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_time(cls, v, info): + return fv.validate_time(v, info) + + @model_validator(mode="after") + def validate_model(self) -> Self: + if isinstance(self.spiketrain, (np.ndarray, list)): + if isinstance(self.t_start, pq.Quantity) or isinstance(self.t_stop, pq.Quantity): + raise TypeError("spiketrain is a np.ndarray or list but t_start or t_stop is pq.Quantity") + elif not (isinstance(self.t_start, pq.Quantity) and isinstance(self.t_stop, pq.Quantity)): + raise TypeError("spiketrain is a neo.SpikeTrain or pq.Quantity but t_start or t_stop is not pq.Quantity") + return self + +class PydanticInstantaneousRate(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.instantaneous_rate function + with additional type checking and json_schema by PyDantic. + """ + + class KernelOptions(Enum): + auto = "auto" + + spiketrains: Any = Field(..., description="Input spike train(s)") + sampling_period: Any = Field(..., gt=0, description="Time stamp resolution of spike times") + kernel: Union[KernelOptions, Any] = Field(KernelOptions.auto, description="Kernel for convolution") + cutoff: Optional[float] = Field(5.0, gt=0, description="cutoff of probability distribution") + t_start: Optional[Any] = Field(None, ge=0, description="Start time") + t_stop: Optional[Any] = Field(None, gt=0, description="Stop time") + trim: Optional[bool] = Field(False, description="Only return region of convolved signal") + center_kernel: Optional[bool] = Field(True, description="Center the kernel on spike") + border_correction: Optional[bool] = Field(False, description="Apply border correction") + pool_trials: Optional[bool] = Field(False, description="Calc firing rates averaged over trials when spiketrains is Trials object") + pool_spike_trains: Optional[bool] = Field(False, description="Calc firing rates averaged over spiketrains") + + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + if(isinstance(v, list)): + return fv.validate_spiketrains(v, info, allowed_types=(list,), allowed_content_types=(neo.SpikeTrain,)) + return fv.validate_spiketrain(v, info, allowed_types=(neo.SpikeTrain, elephant.trials.Trials)) + + @field_validator("sampling_period") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + + @field_validator("kernel") + @classmethod + def validate_kernel(cls, v, info): + return fv.validate_type(v, info, allowed_types=(cls.KernelOptions, Kernel), allow_none=False) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_time(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @model_validator(mode="after") + def validate_model(self) -> Self: + if(isinstance(self.kernel, Kernel) and self.cutoff < self.kernel.min_cutoff): + warnings.warn(f"cutoff {self.cutoff} is smaller than the minimum cutoff {self.kernel.min_cutoff} of the kernel", UserWarning) + fv.model_validate_spiketrains_same_t_start_stop(self.spiketrains, self.t_start, self.t_stop, warning=True) + return self + +class PydanticTimeHistogram(BaseModel): + 
""" + PyDantic Class to wrap the elephant.statistics.time_histogram function + with additional type checking and json_schema by PyDantic. + """ + + class OutputOptions(Enum): + counts = "counts" + mean = "mean" + rate = "rate" + + spiketrains: list = Field(..., description="List of Spiketrains") + bin_size: Any = Field(..., description="Width histogram's time bins") + t_start: Optional[Any] = Field(None, description="Start time") + t_stop: Optional[Any] = Field(None, description="Stop time") + output: Optional[OutputOptions] = Field(OutputOptions.counts, description="Normalization") + binary: Optional[bool] = Field(False, description="To binary") + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + return fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,)) + + @field_validator("bin_size") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_quantity_none(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @model_validator(mode="after") + def validate_model(self) -> Self: + fv.model_validate_spiketrains_same_t_start_stop(self.spiketrains, self.t_start, self.t_stop, warning=True) + return self + +class PydanticOptimalKernelBandwidth(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.optimal_kernel_bandwidth function + with additional type checking and json_schema by PyDantic. + """ + + spiketimes: Any = Field(..., description="Sequence of spike times(ASC)") + times: Optional[Any] = Field(None, description="Time at which kernel bandwidth") + bandwidth: Optional[Any] = Field(None, description="Vector of kernal bandwidth") + bootstrap: Optional[bool] = Field(False, description="Use Bootstrap") + + @field_validator("spiketimes") + @classmethod + def validate_ndarray(cls, v, info): + return fv.validate_array(v, info, allowed_types=(np.ndarray,)) + + @field_validator("times", "bandwidth") + @classmethod + def validate_ndarray_none(cls, v, info): + return fv.validate_array(v, info, allowed_types=(np.ndarray,), allow_none=True) + +class PydanticIsi(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.isi function + with additional type checking and json_schema by PyDantic. + """ + spiketrain: Any = Field(..., description="SpikeTrain Object (sorted)") + axis: Optional[int] = Field(-1, description="Difference Axis") + + @field_validator("spiketrain") + @classmethod + def validate_spiketrain_sorted(cls, v, info): + return fv.validate_spiketrain(v, info, check_sorted=True) + +class PydanticCv(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.cv function + with additional type checking and json_schema by PyDantic. 
+ """ + class NanPolicyOptions(Enum): + propagate = "propagate" + omit = "omit" + _raise = "raise" + + a: Any = Field(..., description="Input array") + axis: Union[int, None] = Field(0, description="Compute statistic axis") + nan_policy: NanPolicyOptions = Field(NanPolicyOptions.propagate, description="How handle input NaNs") + ddof: Optional[int] = Field(0, ge=0, description="Delta Degrees Of Freedom") + keepdims: Optional[bool] = Field(False, description="leave reduced axes in one-dimensional result") + + @field_validator("a") + @classmethod + def validate_array(cls, v, info): + return fv.validate_array(v, info) + +class PydanticCv2(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.cv2 function + with additional type checking and json_schema by PyDantic. + """ + + time_intervals: Any = Field(..., description="Vector of time intervals") + with_nan: Optional[bool] = Field(False, description="Do not Raise warning on short spike train") + + @field_validator("time_intervals") + @classmethod + def validate_time_intervals(cls, v, info): + return fv.validate_time_intervals(v, info, check_matrix=True) + + @model_validator(mode="after") + def validate_model(self) -> Self: + fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) + return self + +class PydanticLv(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.lv function + with additional type checking and json_schema by PyDantic. + """ + + time_intervals: Any = Field(..., description="Vector of time intervals") + with_nan: Optional[bool] = Field(False, description="Do not Raise warning on short spike train") + + @field_validator("time_intervals") + @classmethod + def validate_time_intervals(cls, v, info): + return fv.validate_time_intervals(v, info, check_matrix=True) + + @model_validator(mode="after") + def validate_model(self) -> Self: + fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) + return self + +class PydanticLvr(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.lvr function + with additional type checking and json_schema by PyDantic. + """ + + time_intervals: Any = Field(..., description="Vector of time intervals (default units: ms)") + R: Any = Field(default_factory=lambda: 5. * pq.ms, ge=0, description="Refractoriness constant (default quantity: ms)") + with_nan: Optional[bool] = Field(False, description="Do not Raise warning on short spike train") + + @field_serializer("R", mode='plain') + def serialize_quantity(self, v): + return fs.serialize_quantity(v) + + @field_validator("time_intervals") + @classmethod + def validate_time_intervals(cls, v, info): + return fv.validate_time_intervals(v, info, check_matrix=True) + + @field_validator("R") + @classmethod + def validate_R(cls, v, info): + fv.validate_type(v, info, (pq.Quantity, int, float), allow_none=False) + if(not isinstance(v, pq.Quantity)): + warnings.warn("R does not have any units so milliseconds are assumed", UserWarning) + return v + + @model_validator(mode="after") + def validate_model(self) -> Self: + fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) + return self + +class PydanticFanofactor(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.fanofactor function + with additional type checking and json_schema by PyDantic. 
+ """ + spiketrains: list = Field(..., description="List of Spiketrains") + warn_tolerance: Any = Field(default_factory=lambda: 0.1 * pq.ms, ge=0, description="Warn tolerence of variations") + + @field_serializer("warn_tolerance", mode='plain') + def serialize_quantity(self, v): + return fs.serialize_quantity(v) + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + return fv.validate_spiketrains(v, info) + + @field_validator("warn_tolerance") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + +class PydanticComplexityPdf(BaseModel): + """ + PyDantic Class to wrap the elephant.statistics.complexity_pdf function + with additional type checking and json_schema by PyDantic. + """ + spiketrains: list = Field(..., description="List of Spiketrains") + bin_size: Any = Field(..., description="Width histogram's time bins") + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,)) + fv.model_validate_spiketrains_same_t_start_stop(v, None, None) + return v + + @field_validator("bin_size") + @classmethod + def validate_quantity(cls, v, info): + return fv.validate_quantity(v, info) + +class PydanticComplexityInit(BaseModel): + spiketrains: list = Field(..., description="List of neo.SpikeTrain objects with common t_start/t_stop") + sampling_rate: Optional[Any] = Field(None, description="Sampling rate (1/time)") + bin_size: Optional[Any] = Field(None, description="Width of histogram bins") + binary: Optional[bool] = Field(True, description="If True count neurons, else total spikes") + spread: Optional[int] = Field(0, ge=0, description="Number of bins for synchronous spikes (>=0)") + tolerance: Optional[float] = Field(1e-8, description="Tolerance for rounding errors") + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,)) + fv.model_validate_spiketrains_same_t_start_stop(v, None, None) + return v + + @field_validator("bin_size") + @classmethod + def validate_bin_size(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @field_validator("sampling_rate") + @classmethod + def validate_sampling_rate(cls, v, info): + fv.validate_quantity(v, info, allow_none=True) + if v is None: + warnings.warn("no sampling rate is supplied. 
This may lead to rounding errors when using the epoch to slice spike trains", UserWarning) + return v + + @model_validator(mode="after") + def check_rate_or_bin(self): + if self.sampling_rate is None and self.bin_size is None: + raise ValueError("Either sampling_rate or bin_size must be set") + return self \ No newline at end of file From 5fb18dc2e743c69f7e784141120bd07faa95b92a Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 09:05:37 +0100 Subject: [PATCH 06/29] Added Pydantic Models for spike_train_correlation --- .../schemas/schema_spike_train_correlation.py | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 elephant/schemas/schema_spike_train_correlation.py diff --git a/elephant/schemas/schema_spike_train_correlation.py b/elephant/schemas/schema_spike_train_correlation.py new file mode 100644 index 000000000..46ca22611 --- /dev/null +++ b/elephant/schemas/schema_spike_train_correlation.py @@ -0,0 +1,152 @@ +import quantities as pq +import numpy as np +from typing import ( + Any, + Union, + Self, + Optional +) +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, + field_serializer +) +import neo +from enum import Enum + +import elephant.schemas.field_validator as fv +import elephant.schemas.field_serializer as fs + +class PydanticCovariance(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.covariance function + with additional type checking and json_schema by PyDantic. + """ + + binned_spiketrain: Any = Field(..., description="Binned spike train") + binary: Optional[bool] = Field(False, description="Use binary binned vectors") + fast: Optional[bool] = Field(True, description="Use faster implementation") + + @field_validator("binned_spiketrain") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @model_validator(mode="after") + def validate_model(self) -> Self: + fv.model_validate_binned_spiketrain_fast(self.binned_spiketrain, self.fast) + return self + + +class PydanticCorrelationCoefficient(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.correlation_coefficient function + with additional type checking and json_schema by PyDantic. + """ + + binned_spiketrain: Any = Field(..., description="Binned spike train") + binary: Optional[bool] = Field(False, description="Use binary binned vectors") + fast: Optional[bool] = Field(True, description="Use faster implementation") + + @field_validator("binned_spiketrain") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @model_validator(mode="after") + def validate_model(self) -> Self: + fv.model_validate_binned_spiketrain_fast(self.binned_spiketrain, self.fast) + return self + + +class PydanticCrossCorrelationHistogram(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.cross_correlation_histogram function + with additional type checking and json_schema by PyDantic. 
+ """ + + class WindowOptions(Enum): + full = "full" + valid = "valid" + + class MethodOptions(Enum): + speed = "speed" + memory = "memory" + + binned_spiketrain_i: Any = Field(..., description="Binned spike train i") + binned_spiketrain_j: Any = Field(..., description="Binned spike train j") + window: Optional[Union[WindowOptions, list[int]]] = Field(WindowOptions.full, description="Window") + border_correction: Optional[bool] = Field(False, description="Correct border effect") + binary: Optional[bool] = Field(False, description="Count spike falling same bin as one") + kernel: Optional[Any] = Field(None, description="array containing a smoothing kernel") + method: Optional[MethodOptions] = Field(MethodOptions.speed, description="Method of calculating") + cross_correlation_coefficient: Optional[bool] = Field(False, description="Normalize CCH") + + @field_validator("binned_spiketrain_i", "binned_spiketrain_j") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @field_validator("kernel") + @classmethod + def validate_kernel(cls, v, info): + return fv.validate_array(v, info, allowed_types=(np.ndarray,), allow_none=True) + + +class PydanticSpikeTimeTilingCoefficient(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.spike_time_tiling_coefficient function + with additional type checking and json_schema by PyDantic. + """ + + spiketrain_i: Any = Field(..., description="Spike train Object i") + spiketrain_j: Any = Field(..., description="Spike train Object j (same T_start and same t_stop)") + dt: Any = Field(default_factory=lambda: 0.005 * pq.s, description="Synchronicity window") + + @field_serializer("dt", mode='plain') + def serialize_quantity(self, value: pq.Quantity): + return fs.serialize_quantity(value) + + @field_validator("spiketrain_i", "spiketrain_j") + @classmethod + def validate_spiketrain(cls, v, info): + # require specifically neo.core.SpikeTrain for this validator + return fv.validate_spiketrain(v, info, allowed_types=(neo.core.SpikeTrain,)) + + @field_validator("dt") + @classmethod + def validate_dt(cls, v, info): + return fv.validate_quantity(v, info) + + @model_validator(mode="after") + def check_correctTypeCombination(self) -> Self: + fv.model_validate_spiketrains_sam_t_start_stop(self.spiketrain_i, self.spiketrain_j) + return self + + +class PydanticSpikeTrainTimescale(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_correlation.spike_train_timescale function + with additional type checking and json_schema by PyDantic. 
+ """ + + binned_spiketrain: Any = Field(..., description="Binned spike train") + max_tau: Any = Field(..., description="Maximal integration time") + + @field_validator("binned_spiketrain") + @classmethod + def validate_binned_spiketrain(cls, v, info): + return fv.validate_binned_spiketrain(v, info) + + @field_validator("max_tau") + @classmethod + def validate_max_tau(cls, v, info): + return fv.validate_quantity(v, info) + + @model_validator(mode="after") + def check_correctTypeCombination(self) -> Self: + if self.max_tau % self.binned_spiketrain.bin_size != 0: + raise ValueError("max_tau has to be a multiple of binned_spiketrain.bin_size") + return self \ No newline at end of file From d5324bfe21e4726ec3a133de27599d8779031550 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 09:06:41 +0100 Subject: [PATCH 07/29] Added Pydantic Models for spike_train_synchrony --- .../schemas/schema_spike_train_synchrony.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 elephant/schemas/schema_spike_train_synchrony.py diff --git a/elephant/schemas/schema_spike_train_synchrony.py b/elephant/schemas/schema_spike_train_synchrony.py new file mode 100644 index 000000000..4c967f874 --- /dev/null +++ b/elephant/schemas/schema_spike_train_synchrony.py @@ -0,0 +1,64 @@ +import quantities as pq +import numpy as np +from typing import ( + Any, + Optional +) +from pydantic import ( + BaseModel, + Field, + field_validator, + field_serializer +) +import neo +from enum import Enum +from elephant.schemas.schema_statistics import PydanticComplexityInit + +import elephant.schemas.field_validator as fv +import elephant.schemas.field_serializer as fs + + +class PydanticSpikeContrast(BaseModel): + """ + PyDantic Class to wrap the elephant.spike_train_synchrony.spike_contrast function + with additional type checking and json_schema by PyDantic. + """ + + spiketrains: list = Field(..., description="List of Spiketrains") + t_start: Optional[Any] = Field(None, description="Start time") + t_stop: Optional[Any] = Field(None, description="Stop time") + min_bin: Optional[Any] = Field(default_factory=lambda: 10. * pq.ms, description="Min value for bin_min") + bin_shrink_factortime: Optional[float] = Field(0.9, description="Shrink bin size multiplier", ge=0., le=1.) 
+ return_trace: Optional[bool] = Field(False, description="Return history of spike-contrast synchrony") + + @field_serializer("min_bin", mode='plain') + def serialize_quantity(self, value: pq.Quantity): + return fs.serialize_quantity(value) + + @field_validator("spiketrains") + @classmethod + def validate_spiketrains(cls, v, info): + return fv.validate_spiketrains(v, info, allowed_content_types=(neo.SpikeTrain,), min_length=2, min_length_content=2) + + @field_validator("t_start", "t_stop") + @classmethod + def validate_time(cls, v, info): + return fv.validate_quantity(v, info, allow_none=True) + + @field_validator("min_bin") + @classmethod + def validate_min_bin(cls, v, info): + return fv.validate_quantity(v, info) + + +class PydanticSynchrotoolInit(PydanticComplexityInit): + pass + +class PydanticSynchrotoolDeleteSynchrofacts(BaseModel): + class ModeOptions(Enum): + delete = "delete" + extract = "extract" + + threshold: int = Field(..., gt=1, description="Threshold for deletion of spikes") + in_place: Optional[bool] = Field(False, description="Make modification in place") + mode: Optional[ModeOptions] = Field(ModeOptions.delete, description="Inversion of mask for deletion") \ No newline at end of file From c99b6c016f78ca8853a3e58253478227716f77e3 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 09:49:46 +0100 Subject: [PATCH 08/29] Original arguments are passed into the function --- elephant/schemas/function_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/schemas/function_validator.py b/elephant/schemas/function_validator.py index 8d9f60686..bcfa20a66 100644 --- a/elephant/schemas/function_validator.py +++ b/elephant/schemas/function_validator.py @@ -21,7 +21,7 @@ def wrapper(*args, **kwargs): validated = model_class(**data) # Call function with validated data unpacked - return func(**validated.model_dump()) + return func(*args, **kwargs) return wrapper return decorator \ No newline at end of file From cb533cc25d9d40d4eec389cf5bf9073df68fe255 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 09:50:40 +0100 Subject: [PATCH 09/29] Added pytest.ini to .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 6f6651146..f73ca14e3 100644 --- a/.gitignore +++ b/.gitignore @@ -89,3 +89,5 @@ ignored/ # neo logs **/logs + +pytest.ini \ No newline at end of file From 71b7dc0ff2d776d27aacee1af0340dd62d7a7157 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 11:01:33 +0100 Subject: [PATCH 10/29] Added tests and option to skip validation --- elephant/schemas/function_validator.py | 11 +- elephant/test/test_schemas.py | 294 +++++++++++++++++++++++++ 2 files changed, 302 insertions(+), 3 deletions(-) create mode 100644 elephant/test/test_schemas.py diff --git a/elephant/schemas/function_validator.py b/elephant/schemas/function_validator.py index bcfa20a66..c19304685 100644 --- a/elephant/schemas/function_validator.py +++ b/elephant/schemas/function_validator.py @@ -12,16 +12,21 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + + if kwargs.pop("not_validate", False): + # skip validation, call inner function directly + return func(*args, **kwargs) + # Bind args & kwargs to function parameters bound = sig.bind_partial(*args, **kwargs) bound.apply_defaults() data = bound.arguments # Validate using Pydantic - validated = model_class(**data) + model_class(**data) - # Call function with validated data unpacked + # Call function return func(*args, **kwargs) - + 
wrapper._is_validate_with = True return wrapper return decorator \ No newline at end of file diff --git a/elephant/test/test_schemas.py b/elephant/test/test_schemas.py new file mode 100644 index 000000000..7297a6b1c --- /dev/null +++ b/elephant/test/test_schemas.py @@ -0,0 +1,294 @@ + +import pytest +import quantities as pq +import neo +import numpy as np + +import elephant + +from pydantic import ValidationError + +from elephant.schemas.schema_statistics import *; +from elephant.schemas.schema_spike_train_correlation import *; +from elephant.schemas.schema_spike_train_synchrony import *; + + +def test_model_json_schema(): + # Just test that json_schema generation runs without error for all models + model_classes = [ + PydanticCovariance, + PydanticCorrelationCoefficient, + PydanticCrossCorrelationHistogram, + PydanticSpikeTimeTilingCoefficient, + PydanticSpikeTrainTimescale, + PydanticMeanFiringRate, + PydanticInstantaneousRate, + PydanticTimeHistogram, + PydanticOptimalKernelBandwidth, + PydanticIsi, + PydanticCv, + PydanticCv2, + PydanticLv, + PydanticLvr, + PydanticFanofactor, + PydanticComplexityPdf, + PydanticComplexityInit, + PydanticSpikeContrast, + PydanticSynchrotoolInit, + PydanticSynchrotoolDeleteSynchrofacts, + ] + for cls in model_classes: + schema = cls.model_json_schema() + assert isinstance(schema, dict) + +""" +Checking for consistent behavior between Elephant functions and Pydantic models. +Tests bypass validate_with decorator if it is already implemented for that function +so consistency is checked correctly +""" + +def call_elephant_function(elephant_fn, kwargs): + if hasattr(elephant_fn, "_is_validate_with"): + kwargs["not_validate"]=True + elephant_fn(**kwargs) + else: + elephant_fn(**kwargs) + +def assert_both_succeed_consistently(elephant_fn, model_cls, kwargs): + """Call both the Elephant function and the Pydantic model with the same kwargs. + Assert both complete without raising exceptions. + + Parameters + - elephant_fn: callable to invoke with kwargs + - model_cls: Pydantic model class to instantiate with kwargs + - kwargs: dict of keyword arguments to pass to both + """ + try: + call_elephant_function(elephant_fn, kwargs) + except Exception as e: + assert False, f"Elephant function raised an exception: {e}" + + try: + model_cls(**kwargs) + except Exception as e: + assert False, f"Pydantic model raised an exception: {e}" + +def assert_both_warn_consistently(elephant_fn, model_cls, kwargs): + """Call both the Elephant function and the Pydantic model with the same kwargs. + Assert both raise warnings. + + Parameters + - elephant_fn: callable to invoke with kwargs + - model_cls: Pydantic model class to instantiate with kwargs + - kwargs: dict of keyword arguments to pass to both + """ + with pytest.warns(Warning) as w1: + call_elephant_function(elephant_fn, kwargs) + with pytest.warns(Warning) as w2: + model_cls(**kwargs) + + +def assert_both_raise_consistently(elephant_fn, model_cls, kwargs, *, same_type=False, expected_exception=None): + """Call both the Elephant function and the Pydantic model with the same kwargs. + Assert both raise, and if requested assert they raise the same exception type. + + Uses pytest.raises to capture exceptions so failures are reported with pytest's + native formatting while still allowing comparison of exception objects. 
+ + Parameters + - elephant_fn: callable to invoke with kwargs + - model_cls: Pydantic model class to instantiate with kwargs + - kwargs: dict of keyword arguments to pass to both + - same_type: if True assert the raised exception classes are identical + - expected_exception: optional exception type that both must be instances of + """ + with pytest.raises(Exception) as e1: + call_elephant_function(elephant_fn, kwargs) + with pytest.raises(Exception) as e2: + model_cls(**kwargs) + + exc1 = e1.value + exc2 = e2.value + + if expected_exception is not None: + assert isinstance(exc1, expected_exception), ( + f"Elephant raised {type(exc1)}, expected {expected_exception}") + assert isinstance(exc2, expected_exception), ( + f"Pydantic raised {type(exc2)}, expected {expected_exception}") + + if same_type: + if(type(exc1) is type(exc2)): + return + + if (isinstance(exc1, (ValueError, TypeError)) and isinstance(exc2, (ValidationError, AttributeError))): + return + + assert False, ( + f"Different exception types: Elephant={type(exc1)}, Pydantic={type(exc2)}. " + f"Elephant exc: {exc1}; Pydantic exc: {exc2}") + +@pytest.fixture +def make_list(): + return [0.01, 0.02, 0.05] + +@pytest.fixture +def make_ndarray(make_list): + return np.array(make_list) + +@pytest.fixture +def make_pq_single_quantity(): + return 0.05 * pq.s + +@pytest.fixture +def make_pq_multiple_quantity(make_ndarray): + return make_ndarray * pq.s + +@pytest.fixture +def make_spiketrain(make_pq_multiple_quantity): + return neo.core.SpikeTrain(make_pq_multiple_quantity, t_start=0 * pq.s, t_stop=0.1 * pq.s) + +@pytest.fixture +def make_spiketrains(make_spiketrain): + return [make_spiketrain, make_spiketrain] + +@pytest.fixture +def make_binned_spiketrain(make_spiketrain): + return elephant.conversion.BinnedSpikeTrain(make_spiketrain, bin_size=0.01 * pq.s) + +@pytest.fixture +def make_analog_signal(): + n2 = 300 + n0 = 100000 - n2 + return neo.AnalogSignal(np.array([10] * n2 + [0] * n0).reshape(n0 + n2, 1) * pq.dimensionless, sampling_period=1 * pq.s) + +@pytest.fixture +def fixture(request): + return request.getfixturevalue(request.param) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.mean_firing_rate, PydanticMeanFiringRate), + (elephant.statistics.isi, PydanticIsi), +]) +@pytest.mark.parametrize("fixture", [ + "make_list", + "make_spiketrain", + "make_ndarray", + "make_pq_multiple_quantity", +], indirect=["fixture"]) +def test_valid_spiketrain_input(elephant_fn, model_cls, fixture): + valid = {"spiketrain": fixture} + assert_both_succeed_consistently(elephant_fn, model_cls, valid) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.mean_firing_rate, PydanticMeanFiringRate), + (elephant.statistics.isi, PydanticIsi), +]) +@pytest.mark.parametrize("spiketrain", [ + 5, + "hello", +]) +def test_invalid_spiketrain(elephant_fn, model_cls, spiketrain): + invalid = {"spiketrain": spiketrain} + assert_both_raise_consistently(elephant_fn, model_cls, invalid) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.time_histogram, PydanticTimeHistogram), + (elephant.statistics.complexity_pdf, PydanticComplexityPdf), +]) +def test_valid_pq_quantity(elephant_fn, model_cls, make_spiketrains, make_pq_single_quantity): + valid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity} + assert_both_succeed_consistently(elephant_fn, model_cls, valid) + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.time_histogram, 
PydanticTimeHistogram), + (elephant.statistics.complexity_pdf, PydanticComplexityPdf), +]) +@pytest.mark.parametrize("pq_quantity", [ + 5, + "hello", + [0.01, 0.02] +]) +def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quantity): + valid = {"spiketrains": make_spiketrains, "bin_size": pq_quantity} + assert_both_raise_consistently(elephant_fn, model_cls, valid) + + + +@pytest.mark.parametrize("elephant_fn,model_cls", [ + (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate), +]) +@pytest.mark.parametrize("fixture", [ + "make_list", + "make_ndarray", + "make_pq_multiple_quantity", +], indirect=["fixture"]) +def test_invalid_spiketrains(elephant_fn, model_cls, fixture, make_pq_single_quantity): + invalid = {"spiketrains": fixture, "sampling_period": make_pq_single_quantity} + assert_both_raise_consistently(elephant_fn, model_cls, invalid) + +@pytest.mark.parametrize("output", [ + "counts", + "mean", + "rate", +]) +def test_valid_enum(output, make_spiketrains, make_pq_single_quantity): + valid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output} + assert_both_succeed_consistently(elephant.statistics.time_histogram, PydanticTimeHistogram, valid) + +@pytest.mark.parametrize("output", [ + "countsfagre", + 5, + "Counts", + "counts ", + " counts", + "counts\n" +]) +def test_invalid_enum(output, make_spiketrains, make_pq_single_quantity): + invalid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output} + assert_both_raise_consistently(elephant.statistics.time_histogram, PydanticTimeHistogram, invalid) + + +def test_valid_binned_spiketrain(make_binned_spiketrain): + valid = {"binned_spiketrain": make_binned_spiketrain} + assert_both_succeed_consistently( + elephant.spike_train_correlation.covariance, + PydanticCovariance, + valid + ) + +def test_invalid_binned_spiketrain(make_spiketrain): + invalid = {"binned_spiketrain": make_spiketrain} + assert_both_raise_consistently( + elephant.spike_train_correlation.covariance, + PydanticCovariance, + invalid, + ) + +@pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [ + (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate, "spiketrains", []), + (elephant.statistics.optimal_kernel_bandwidth, PydanticOptimalKernelBandwidth, "spiketimes", np.array([])), + (elephant.statistics.cv2, PydanticCv2, "time_intervals", np.array([])*pq.s), +]) +def test_invalid_empty_input(elephant_fn, model_cls, parameter_name, empty_input): + invalid = {parameter_name: empty_input} + assert_both_raise_consistently(elephant_fn, model_cls, invalid) + +@pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [ + (elephant.spike_train_correlation.covariance, PydanticCovariance, "binned_spiketrain", elephant.conversion.BinnedSpikeTrain(neo.core.SpikeTrain(np.array([])*pq.s, t_start=0*pq.s, t_stop=1*pq.s), bin_size=0.01*pq.s)), +]) +def test_warning_empty_input(elephant_fn, model_cls, parameter_name, empty_input): + warning = {parameter_name: empty_input} + assert_both_warn_consistently(elephant_fn, model_cls, warning) + + +def test_valid_Complexity(make_spiketrains, make_pq_single_quantity): + valid = { "spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity } + assert_both_succeed_consistently( + elephant.statistics.Complexity, + PydanticComplexityInit, + valid, + ) \ No newline at end of file From 7f4f5ef5cedb5af10f6bb5a1fc6136b585bc6a95 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 
12:58:06 +0100 Subject: [PATCH 11/29] Transfering Bug fixes --- elephant/schemas/field_validator.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/elephant/schemas/field_validator.py b/elephant/schemas/field_validator.py index dbd01726e..6a61323c5 100644 --- a/elephant/schemas/field_validator.py +++ b/elephant/schemas/field_validator.py @@ -39,6 +39,8 @@ def get_length(obj) -> int: return obj.size elif isinstance(obj, (list,tuple)): return len(obj) + elif isinstance(obj, neo.core.spiketrainlist.SpikeTrainList): + return len(obj) @@ -115,6 +117,7 @@ def validate_type_length(value, info, allowed_types: tuple, allow_none: bool, mi def validate_array_content(value, info, allowed_types: tuple, allow_none: bool, min_length: int, allowed_content_types: tuple, min_length_content: int = 0): validate_type_length(value, info, allowed_types, allow_none, min_length) + hasContentLength = False for i, item in enumerate(value): if not isinstance(item, allowed_content_types): raise TypeError(f"Element {i} in {info.field_name} must be {allowed_content_types}, not {type(item).__name__}") @@ -144,7 +147,7 @@ def validate_spiketrains(value, info, allowed_types = (list,), allow_none = Fals def validate_spiketrains_matrix(value, info, allowed_types = (elephant.trials.Trials, list[neo.core.spiketrainlist.SpikeTrainList], list[list[neo.core.SpikeTrain]]), allow_none = False, min_length = 1, check_rank_deficient = False): if isinstance(value, list): - validate_spiketrains(value, info, allowed_content_types=(neo.core.spiketrainlist,list[neo.core.SpikeTrain],)) + validate_spiketrains(value, info, allowed_content_types=(neo.core.spiketrainlist.SpikeTrainList,list[neo.core.SpikeTrain],)) else: validate_type(value, info, (elephant.trials.Trials,), allow_none=False) if check_rank_deficient: @@ -169,7 +172,7 @@ def validate_time_intervals(value, info, allowed_types = (list, pq.Quantity, np. 
raise ValueError(f"{info.field_name} is not allowed to be a matrix") return value -def validate_array(value, info, allowed_types=(list, np.ndarray) , allow_none=False, min_length=1, allowed_content_types = None, min_length_content = 0): +def validate_array(value, info, allowed_types=(list, np.ndarray, tuple) , allow_none=False, min_length=1, allowed_content_types = None, min_length_content = 0): if allowed_content_types is None: validate_type_length(value, info, allowed_types, allow_none, min_length) else: @@ -202,10 +205,10 @@ def validate_key_in_tuple(value : str, info, t: tuple): # ---- Model validation helpers ---- -def model_validate_spiketrains_same_t_start_stop(spiketrain, t_start, t_stop, name: str = "spiketrains", warning: bool = False): +def model_validate_spiketrains_same_t_start_stop(spiketrains, t_start, t_stop, name: str = "spiketrains", warning: bool = False): if(t_start is None or t_stop is None): first = True - for i, item in enumerate(spiketrain): + for i, item in enumerate(spiketrains): if first: t_start = item.t_start t_stop = item.t_stop @@ -225,7 +228,7 @@ def model_validate_spiketrains_same_t_start_stop(spiketrain, t_start, t_stop, na if t_start>t_stop: raise ValueError(f"{name} has t_start > t_stop") -def model_validate_spiketrains_sam_t_start_stop(spiketrain_i, spiketrain_j): +def model_validate_two_spiketrains_same_t_start_stop(spiketrain_i, spiketrain_j): if spiketrain_i.t_start != spiketrain_j.t_start: raise ValueError("spiketrain_i and spiketrain_j need to have the same t_start") if spiketrain_i.t_stop != spiketrain_j.t_stop: From b282cc3b6bd90884cbd873d82e120fd5ca9a0f2a Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 12:58:52 +0100 Subject: [PATCH 12/29] Transfering Bug fixes --- elephant/schemas/schema_statistics.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py index 9314e2521..fa2b8c09d 100644 --- a/elephant/schemas/schema_statistics.py +++ b/elephant/schemas/schema_statistics.py @@ -64,7 +64,7 @@ class KernelOptions(Enum): spiketrains: Any = Field(..., description="Input spike train(s)") sampling_period: Any = Field(..., gt=0, description="Time stamp resolution of spike times") kernel: Union[KernelOptions, Any] = Field(KernelOptions.auto, description="Kernel for convolution") - cutoff: Optional[float] = Field(5.0, gt=0, description="cutoff of probability distribution") + cutoff: Optional[float] = Field(5.0, ge=0, description="cutoff of probability distribution") t_start: Optional[Any] = Field(None, ge=0, description="Start time") t_stop: Optional[Any] = Field(None, gt=0, description="Stop time") trim: Optional[bool] = Field(False, description="Only return region of convolved signal") @@ -77,9 +77,11 @@ class KernelOptions(Enum): @field_validator("spiketrains") @classmethod def validate_spiketrains(cls, v, info): - if(isinstance(v, list)): - return fv.validate_spiketrains(v, info, allowed_types=(list,), allowed_content_types=(neo.SpikeTrain,)) - return fv.validate_spiketrain(v, info, allowed_types=(neo.SpikeTrain, elephant.trials.Trials)) + if(isinstance(v, (list, neo.core.spiketrainlist.SpikeTrainList))): + return fv.validate_spiketrains(v, info, allowed_types=(list, neo.core.spiketrainlist.SpikeTrainList), allowed_content_types=(neo.SpikeTrain,)) + if(isinstance(v, neo.SpikeTrain)): + return fv.validate_spiketrain(v, info, allowed_types=(neo.SpikeTrain,)) + return fv.validate_spiketrains_matrix(v, info) 
@field_validator("sampling_period") @classmethod @@ -89,7 +91,9 @@ def validate_quantity(cls, v, info): @field_validator("kernel") @classmethod def validate_kernel(cls, v, info): - return fv.validate_type(v, info, allowed_types=(cls.KernelOptions, Kernel), allow_none=False) + if v == cls.KernelOptions.auto.value: + return v + return fv.validate_type(v, info, allowed_types=(Kernel), allow_none=False) @field_validator("t_start", "t_stop") @classmethod @@ -100,7 +104,8 @@ def validate_time(cls, v, info): def validate_model(self) -> Self: if(isinstance(self.kernel, Kernel) and self.cutoff < self.kernel.min_cutoff): warnings.warn(f"cutoff {self.cutoff} is smaller than the minimum cutoff {self.kernel.min_cutoff} of the kernel", UserWarning) - fv.model_validate_spiketrains_same_t_start_stop(self.spiketrains, self.t_start, self.t_stop, warning=True) + if isinstance(self.spiketrains, list): + fv.model_validate_spiketrains_same_t_start_stop(self.spiketrains, self.t_start, self.t_stop, warning=True) return self class PydanticTimeHistogram(BaseModel): @@ -185,13 +190,13 @@ class NanPolicyOptions(Enum): omit = "omit" _raise = "raise" - a: Any = Field(..., description="Input array") + args: Any = Field(..., description="Input array") axis: Union[int, None] = Field(0, description="Compute statistic axis") nan_policy: NanPolicyOptions = Field(NanPolicyOptions.propagate, description="How handle input NaNs") ddof: Optional[int] = Field(0, ge=0, description="Delta Degrees Of Freedom") keepdims: Optional[bool] = Field(False, description="leave reduced axes in one-dimensional result") - @field_validator("a") + @field_validator("args") @classmethod def validate_array(cls, v, info): return fv.validate_array(v, info) @@ -281,7 +286,7 @@ def serialize_quantity(self, v): @field_validator("spiketrains") @classmethod def validate_spiketrains(cls, v, info): - return fv.validate_spiketrains(v, info) + return fv.validate_spiketrains(v, info, min_length=0) @field_validator("warn_tolerance") @classmethod From 9c3402a9279a801f7b1c8ec4bde057819e81bdc3 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 12:59:17 +0100 Subject: [PATCH 13/29] Transfering Bug fixes --- elephant/schemas/schema_spike_train_correlation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/elephant/schemas/schema_spike_train_correlation.py b/elephant/schemas/schema_spike_train_correlation.py index 46ca22611..aaa58db13 100644 --- a/elephant/schemas/schema_spike_train_correlation.py +++ b/elephant/schemas/schema_spike_train_correlation.py @@ -122,7 +122,7 @@ def validate_dt(cls, v, info): @model_validator(mode="after") def check_correctTypeCombination(self) -> Self: - fv.model_validate_spiketrains_sam_t_start_stop(self.spiketrain_i, self.spiketrain_j) + fv.model_validate_two_spiketrains_same_t_start_stop(self.spiketrain_i, self.spiketrain_j) return self @@ -147,6 +147,8 @@ def validate_max_tau(cls, v, info): @model_validator(mode="after") def check_correctTypeCombination(self) -> Self: - if self.max_tau % self.binned_spiketrain.bin_size != 0: + if self.max_tau % self.binned_spiketrain.bin_size > 0.00001: raise ValueError("max_tau has to be a multiple of binned_spiketrain.bin_size") - return self \ No newline at end of file + return self + + From ac00866ff9bad803334457ea9e903256a50fb66d Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 13:08:25 +0100 Subject: [PATCH 14/29] Implemented validation for statistics --- elephant/statistics.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 
deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 45d9cd283..ed20adddb 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -83,6 +83,9 @@ from elephant.utils import deprecated_alias, check_neo_consistency, \ is_time_quantity, round_binning_errors +from elephant.schemas.function_validator import validate_with +from elephant.schemas.schema_statistics import *; + # do not import unicode_literals # (quantities rescale does not work with unicodes) @@ -102,9 +105,12 @@ "optimal_kernel_bandwidth" ] -cv = scipy.stats.variation +@validate_with(PydanticCv) +def cv(*args, **kwargs): + return scipy.stats.variation(*args, **kwargs) +@validate_with(PydanticIsi) def isi(spiketrain, axis=-1): """ Return an array containing the inter-spike intervals of the spike train. @@ -155,7 +161,7 @@ def isi(spiketrain, axis=-1): return intervals - +@validate_with(PydanticMeanFiringRate) def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): """ Return the firing rate of the spike train. @@ -270,6 +276,7 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): return rates +@validate_with(PydanticFanofactor) def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): r""" Evaluates the empirical Fano factor F of the spike counts of @@ -373,6 +380,7 @@ def __variation_check(v, with_nan): return None +@validate_with(PydanticCv2) @deprecated_alias(v='time_intervals') def cv2(time_intervals, with_nan=False): r""" @@ -441,6 +449,7 @@ def cv2(time_intervals, with_nan=False): return 2. * np.mean(np.abs(cv_i)) +@validate_with(PydanticLv) @deprecated_alias(v='time_intervals') def lv(time_intervals, with_nan=False): r""" @@ -508,6 +517,7 @@ def lv(time_intervals, with_nan=False): return 3. * np.mean(np.power(cv_i, 2)) +@validate_with(PydanticLvr) def lvr(time_intervals, R=5*pq.ms, with_nan=False): r""" Calculate the measure of revised local variation LvR for a sequence of time @@ -600,6 +610,7 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False): return lvr +@validate_with(PydanticInstantaneousRate) @deprecated_alias(spiketrain='spiketrains') def instantaneous_rate(spiketrains, sampling_period, kernel='auto', cutoff=5.0, t_start=None, t_stop=None, trim=False, @@ -1061,6 +1072,7 @@ def optimal_kernel(st): return rate +@validate_with(PydanticTimeHistogram) @deprecated_alias(binsize='bin_size') def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, output='counts', binary=False): @@ -1204,6 +1216,7 @@ def _rate() -> pq.Quantity: normalization=output) +@validate_with(PydanticComplexityPdf) @deprecated_alias(binsize='bin_size') def complexity_pdf(spiketrains, bin_size): """ @@ -1418,6 +1431,7 @@ class Complexity(object): """ + @validate_with(PydanticComplexityInit) def __init__(self, spiketrains, sampling_rate=None, bin_size=None, @@ -1716,6 +1730,7 @@ def cost_function(x, N, w, dt): return C, yh +@validate_with(PydanticOptimalKernelBandwidth) @deprecated_alias(tin='times', w='bandwidth') def optimal_kernel_bandwidth(spiketimes, times=None, bandwidth=None, bootstrap=False): From a95f9f85fe63e7c9b8208781ddc60681b5b38e03 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 13:09:12 +0100 Subject: [PATCH 15/29] Implemented validation for spike_train_correlation --- elephant/spike_train_correlation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/elephant/spike_train_correlation.py b/elephant/spike_train_correlation.py index 1d7cf0656..1e34eedb3 100644 --- a/elephant/spike_train_correlation.py +++ 
b/elephant/spike_train_correlation.py @@ -25,6 +25,9 @@ from scipy import integrate from elephant.utils import check_neo_consistency +from elephant.schemas.function_validator import validate_with +from elephant.schemas.schema_spike_train_correlation import *; + __all__ = [ "covariance", @@ -276,6 +279,7 @@ def kernel_smoothing(self, cross_corr_array, kernel): return np.convolve(cross_corr_array, kernel, mode='same') +@validate_with(PydanticCovariance) def covariance(binned_spiketrain, binary=False, fast=True): r""" Calculate the NxN matrix of pairwise covariances between all combinations @@ -376,6 +380,7 @@ def covariance(binned_spiketrain, binary=False, fast=True): binned_spiketrain, corrcoef_norm=False) +@validate_with(PydanticCorrelationCoefficient) def correlation_coefficient(binned_spiketrain, binary=False, fast=True): r""" Calculate the NxN matrix of pairwise Pearson's correlation coefficients @@ -549,6 +554,7 @@ def _covariance_sparse(binned_spiketrain, corrcoef_norm): return res +@validate_with(PydanticCrossCorrelationHistogram) def cross_correlation_histogram( binned_spiketrain_i, binned_spiketrain_j, window='full', border_correction=False, binary=False, kernel=None, method='speed', @@ -818,6 +824,7 @@ def cross_correlation_histogram( cch = cross_correlation_histogram +@validate_with(PydanticSpikeTimeTilingCoefficient) def spike_time_tiling_coefficient(spiketrain_i: neo.core.SpikeTrain, spiketrain_j: neo.core.SpikeTrain, dt: pq.Quantity = 0.005 * pq.s) -> float: @@ -992,6 +999,7 @@ def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float: sttc = spike_time_tiling_coefficient +@validate_with(PydanticSpikeTrainTimescale) def spike_train_timescale(binned_spiketrain, max_tau): r""" Calculates the auto-correlation time of a binned spike train; uses the From 69745c2e1bc83c6dd029a24098796f952aedd672 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 13:09:51 +0100 Subject: [PATCH 16/29] Implemented validation for spike_train_synchrony --- elephant/spike_train_synchrony.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 946a24ae2..84fbe5406 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -29,6 +29,9 @@ from elephant.statistics import Complexity from elephant.utils import is_time_quantity, check_same_units +from elephant.schemas.function_validator import validate_with +from elephant.schemas.schema_spike_train_synchrony import *; + SpikeContrastTrace = namedtuple("SpikeContrastTrace", ( "contrast", "active_spiketrains", "synchrony", "bin_size")) @@ -69,6 +72,7 @@ def _binning_half_overlap(spiketrain, edges): return histogram +@validate_with(PydanticSpikeContrast) def spike_contrast(spiketrains, t_start=None, t_stop=None, min_bin=10 * pq.ms, bin_shrink_factor=0.9, return_trace=False): @@ -261,6 +265,7 @@ class Synchrotool(Complexity): """ + @validate_with(PydanticSynchrotoolInit) def __init__(self, spiketrains, sampling_rate, bin_size=None, @@ -277,6 +282,7 @@ def __init__(self, spiketrains, spread=spread, tolerance=tolerance) + @validate_with(PydanticSynchrotoolDeleteSynchrofacts) def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): """ Delete or extract synchronous spiking events. 
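For illustration only (not part of the patches above): a minimal sketch of how a function is wired to its Pydantic schema through the validate_with decorator that these patches apply across statistics, spike_train_correlation and spike_train_synchrony. The function name "example" and the model "PydanticExample" are hypothetical; the sketch follows the same convention as the real schemas, which keep quantity-like fields typed as Any.

    # Hypothetical sketch: a toy schema plus a decorated function, assuming the
    # validate_with decorator from elephant/schemas/function_validator.py.
    from typing import Any

    import quantities as pq
    from pydantic import BaseModel, Field

    from elephant.schemas.function_validator import validate_with


    class PydanticExample(BaseModel):
        rate: float = Field(..., gt=0, description="Firing rate in Hz")
        t_stop: Any = Field(..., description="End of the observation window")


    @validate_with(PydanticExample)
    def example(rate, t_stop):
        # The body only runs after PydanticExample(rate=..., t_stop=...) succeeded.
        return rate * float(t_stop.rescale(pq.s).magnitude)


    example(10.0, 2 * pq.s)   # passes validation, returns 20.0
    example(-1.0, 2 * pq.s)   # raises pydantic.ValidationError (rate must be > 0)

A call that violates a field constraint is rejected by the Pydantic model before the wrapped function body ever runs, which is the behaviour the decorated Elephant functions in the patches rely on.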
From 460e9cd5eebd92a446cb4df56d6a454fb77aebc3 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Fri, 7 Nov 2025 13:25:49 +0100 Subject: [PATCH 17/29] Allowed some ValueErrors to also be TypeErrors --- elephant/test/test_spike_train_correlation.py | 2 +- elephant/test/test_statistics.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/elephant/test/test_spike_train_correlation.py b/elephant/test/test_spike_train_correlation.py index 90de65ea1..75682332b 100644 --- a/elephant/test/test_spike_train_correlation.py +++ b/elephant/test/test_spike_train_correlation.py @@ -913,7 +913,7 @@ def test_timescale_errors(self): # Tau max with no units tau_max = 1 - self.assertRaises(ValueError, + self.assertRaises((ValueError, TypeError), sc.spike_train_timescale, spikes_bin, tau_max) # Tau max that is not a multiple of the binsize diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..a94c4aa68 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -383,7 +383,7 @@ def test_lv_with_list(self): def test_lv_raise_error(self): seq = self.test_seq self.assertRaises(ValueError, statistics.lv, []) - self.assertRaises(ValueError, statistics.lv, 1) + self.assertRaises((ValueError, TypeError), statistics.lv, 1) self.assertRaises(ValueError, statistics.lv, np.array([seq, seq])) def test_2short_spike_train(self): @@ -430,7 +430,7 @@ def test_lvr_with_list(self): def test_lvr_raise_error(self): seq = self.test_seq self.assertRaises(ValueError, statistics.lvr, []) - self.assertRaises(ValueError, statistics.lvr, 1) + self.assertRaises((ValueError, TypeError), statistics.lvr, 1) self.assertRaises(ValueError, statistics.lvr, np.array([seq, seq])) self.assertRaises(ValueError, statistics.lvr, seq, -1 * pq.ms) @@ -478,7 +478,7 @@ def test_cv2_with_list(self): def test_cv2_raise_error(self): seq = self.test_seq self.assertRaises(ValueError, statistics.cv2, []) - self.assertRaises(ValueError, statistics.cv2, 1) + self.assertRaises((ValueError, TypeError), statistics.cv2, 1) self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq])) From fac02e1b568cae99160cfa59a6f0ad0a8e533458 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 10 Nov 2025 08:57:44 +0100 Subject: [PATCH 18/29] Removed ; at end of lines --- elephant/spike_train_correlation.py | 2 +- elephant/spike_train_synchrony.py | 2 +- elephant/statistics.py | 2 +- elephant/test/test_schemas.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/elephant/spike_train_correlation.py b/elephant/spike_train_correlation.py index 1e34eedb3..6bb03641f 100644 --- a/elephant/spike_train_correlation.py +++ b/elephant/spike_train_correlation.py @@ -26,7 +26,7 @@ from elephant.utils import check_neo_consistency from elephant.schemas.function_validator import validate_with -from elephant.schemas.schema_spike_train_correlation import *; +from elephant.schemas.schema_spike_train_correlation import * __all__ = [ diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 84fbe5406..0e27cb4f7 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -30,7 +30,7 @@ from elephant.utils import is_time_quantity, check_same_units from elephant.schemas.function_validator import validate_with -from elephant.schemas.schema_spike_train_synchrony import *; +from elephant.schemas.schema_spike_train_synchrony import * SpikeContrastTrace = namedtuple("SpikeContrastTrace", ( "contrast", "active_spiketrains", 
"synchrony", "bin_size")) diff --git a/elephant/statistics.py b/elephant/statistics.py index ed20adddb..b6fd9deb5 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -84,7 +84,7 @@ is_time_quantity, round_binning_errors from elephant.schemas.function_validator import validate_with -from elephant.schemas.schema_statistics import *; +from elephant.schemas.schema_statistics import * # do not import unicode_literals # (quantities rescale does not work with unicodes) diff --git a/elephant/test/test_schemas.py b/elephant/test/test_schemas.py index 7297a6b1c..f08b22d1b 100644 --- a/elephant/test/test_schemas.py +++ b/elephant/test/test_schemas.py @@ -8,9 +8,9 @@ from pydantic import ValidationError -from elephant.schemas.schema_statistics import *; -from elephant.schemas.schema_spike_train_correlation import *; -from elephant.schemas.schema_spike_train_synchrony import *; +from elephant.schemas.schema_statistics import * +from elephant.schemas.schema_spike_train_correlation import * +from elephant.schemas.schema_spike_train_synchrony import * def test_model_json_schema(): From 4fdd64ad366e1c76fb1fe6c84b7d8244e304f098 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 10 Nov 2025 13:27:58 +0100 Subject: [PATCH 19/29] Added Pydantic to requirements --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b3b9d6f98..3021ae5f9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,3 +4,4 @@ quantities>=0.14.1 scipy>=1.10.0 six>=1.10.0 tqdm +pydantic>=2.0.0 \ No newline at end of file From 7d933a4a8394230c0f1adc4d448769eff42ce7b2 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Tue, 11 Nov 2025 15:43:29 +0100 Subject: [PATCH 20/29] Removed Self from typing, because it only works in python>=3.11.0 --- elephant/schemas/schema_spike_train_correlation.py | 9 ++++----- elephant/schemas/schema_statistics.py | 13 ++++++------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/elephant/schemas/schema_spike_train_correlation.py b/elephant/schemas/schema_spike_train_correlation.py index aaa58db13..a2153cd0e 100644 --- a/elephant/schemas/schema_spike_train_correlation.py +++ b/elephant/schemas/schema_spike_train_correlation.py @@ -3,7 +3,6 @@ from typing import ( Any, Union, - Self, Optional ) from pydantic import ( @@ -35,7 +34,7 @@ def validate_binned_spiketrain(cls, v, info): return fv.validate_binned_spiketrain(v, info) @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): fv.model_validate_binned_spiketrain_fast(self.binned_spiketrain, self.fast) return self @@ -56,7 +55,7 @@ def validate_binned_spiketrain(cls, v, info): return fv.validate_binned_spiketrain(v, info) @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): fv.model_validate_binned_spiketrain_fast(self.binned_spiketrain, self.fast) return self @@ -121,7 +120,7 @@ def validate_dt(cls, v, info): return fv.validate_quantity(v, info) @model_validator(mode="after") - def check_correctTypeCombination(self) -> Self: + def check_correctTypeCombination(self): fv.model_validate_two_spiketrains_same_t_start_stop(self.spiketrain_i, self.spiketrain_j) return self @@ -146,7 +145,7 @@ def validate_max_tau(cls, v, info): return fv.validate_quantity(v, info) @model_validator(mode="after") - def check_correctTypeCombination(self) -> Self: + def check_correctTypeCombination(self): if self.max_tau % self.binned_spiketrain.bin_size 
> 0.00001: raise ValueError("max_tau has to be a multiple of binned_spiketrain.bin_size") return self diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py index fa2b8c09d..934d7b28a 100644 --- a/elephant/schemas/schema_statistics.py +++ b/elephant/schemas/schema_statistics.py @@ -3,7 +3,6 @@ from typing import ( Any, Union, - Self, Optional ) from pydantic import ( @@ -44,7 +43,7 @@ def validate_time(cls, v, info): return fv.validate_time(v, info) @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): if isinstance(self.spiketrain, (np.ndarray, list)): if isinstance(self.t_start, pq.Quantity) or isinstance(self.t_stop, pq.Quantity): raise TypeError("spiketrain is a np.ndarray or list but t_start or t_stop is pq.Quantity") @@ -101,7 +100,7 @@ def validate_time(cls, v, info): return fv.validate_quantity(v, info, allow_none=True) @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): if(isinstance(self.kernel, Kernel) and self.cutoff < self.kernel.min_cutoff): warnings.warn(f"cutoff {self.cutoff} is smaller than the minimum cutoff {self.kernel.min_cutoff} of the kernel", UserWarning) if isinstance(self.spiketrains, list): @@ -142,7 +141,7 @@ def validate_quantity_none(cls, v, info): return fv.validate_quantity(v, info, allow_none=True) @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): fv.model_validate_spiketrains_same_t_start_stop(self.spiketrains, self.t_start, self.t_stop, warning=True) return self @@ -216,7 +215,7 @@ def validate_time_intervals(cls, v, info): return fv.validate_time_intervals(v, info, check_matrix=True) @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) return self @@ -235,7 +234,7 @@ def validate_time_intervals(cls, v, info): return fv.validate_time_intervals(v, info, check_matrix=True) @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) return self @@ -267,7 +266,7 @@ def validate_R(cls, v, info): return v @model_validator(mode="after") - def validate_model(self) -> Self: + def validate_model(self): fv.model_validate_time_intervals_with_nan(self.time_intervals, self.with_nan) return self From 932f1d49347b6eba6f13f8c391ad2fca08b79f93 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Tue, 11 Nov 2025 15:44:38 +0100 Subject: [PATCH 21/29] Added ability to disable validation globally --- elephant/schemas/function_validator.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/elephant/schemas/function_validator.py b/elephant/schemas/function_validator.py index c19304685..bf8219f0a 100644 --- a/elephant/schemas/function_validator.py +++ b/elephant/schemas/function_validator.py @@ -2,6 +2,8 @@ from inspect import signature from pydantic import BaseModel +skip_validation = False + def validate_with(model_class: type[BaseModel]): """ A decorator that validates the inputs of a function using a Pydantic model. 
@@ -13,7 +15,7 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if kwargs.pop("not_validate", False): + if kwargs.pop("not_validate", False) or skip_validation: # skip validation, call inner function directly return func(*args, **kwargs) @@ -29,4 +31,10 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) wrapper._is_validate_with = True return wrapper - return decorator \ No newline at end of file + return decorator + +def activate_validation(): + skip_validation = False + +def deactivate_validation(): + skip_validation = True \ No newline at end of file From 1f58b120731e20af095feb315142b411df34faf6 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Tue, 11 Nov 2025 16:12:58 +0100 Subject: [PATCH 22/29] Allow t_start to be negative because it should be able to be used that way --- elephant/schemas/schema_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py index 934d7b28a..165bc8cd6 100644 --- a/elephant/schemas/schema_statistics.py +++ b/elephant/schemas/schema_statistics.py @@ -64,7 +64,7 @@ class KernelOptions(Enum): sampling_period: Any = Field(..., gt=0, description="Time stamp resolution of spike times") kernel: Union[KernelOptions, Any] = Field(KernelOptions.auto, description="Kernel for convolution") cutoff: Optional[float] = Field(5.0, ge=0, description="cutoff of probability distribution") - t_start: Optional[Any] = Field(None, ge=0, description="Start time") + t_start: Optional[Any] = Field(None, description="Start time") t_stop: Optional[Any] = Field(None, gt=0, description="Stop time") trim: Optional[bool] = Field(False, description="Only return region of convolved signal") center_kernel: Optional[bool] = Field(True, description="Center the kernel on spike") From 6f2b7d3eb2e9ec0a2e69ca1bccd3ffa09d33b2cb Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Tue, 18 Nov 2025 16:31:44 +0100 Subject: [PATCH 23/29] Allowed all t_start and t_stop to be negative, becuase they could be if seen relativly --- elephant/schemas/schema_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py index 165bc8cd6..9d79644da 100644 --- a/elephant/schemas/schema_statistics.py +++ b/elephant/schemas/schema_statistics.py @@ -65,7 +65,7 @@ class KernelOptions(Enum): kernel: Union[KernelOptions, Any] = Field(KernelOptions.auto, description="Kernel for convolution") cutoff: Optional[float] = Field(5.0, ge=0, description="cutoff of probability distribution") t_start: Optional[Any] = Field(None, description="Start time") - t_stop: Optional[Any] = Field(None, gt=0, description="Stop time") + t_stop: Optional[Any] = Field(None, description="Stop time") trim: Optional[bool] = Field(False, description="Only return region of convolved signal") center_kernel: Optional[bool] = Field(True, description="Center the kernel on spike") border_correction: Optional[bool] = Field(False, description="Apply border correction") From c46228df38394903f4766114ffd5f4f6669d2aa6 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 24 Nov 2025 13:25:13 +0100 Subject: [PATCH 24/29] Removed the option to skip validation with the extra kwargs not_validate, so the api does not need to be changed and to simplify the decorator --- elephant/schemas/function_validator.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/elephant/schemas/function_validator.py 
b/elephant/schemas/function_validator.py index bf8219f0a..9eaf163bd 100644 --- a/elephant/schemas/function_validator.py +++ b/elephant/schemas/function_validator.py @@ -15,21 +15,17 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if kwargs.pop("not_validate", False) or skip_validation: - # skip validation, call inner function directly - return func(*args, **kwargs) + if not skip_validation: + # Bind args & kwargs to function parameters + bound = sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + data = bound.arguments - # Bind args & kwargs to function parameters - bound = sig.bind_partial(*args, **kwargs) - bound.apply_defaults() - data = bound.arguments - - # Validate using Pydantic - model_class(**data) + # Validate using Pydantic + model_class(**data) # Call function return func(*args, **kwargs) - wrapper._is_validate_with = True return wrapper return decorator From ae94ee687acc9e820522abfee8a78e88a4bd77f7 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 24 Nov 2025 14:03:56 +0100 Subject: [PATCH 25/29] Simplified test to make them more understandable --- elephant/test/test_schemas.py | 159 +++++++++++----------------------- 1 file changed, 52 insertions(+), 107 deletions(-) diff --git a/elephant/test/test_schemas.py b/elephant/test/test_schemas.py index f08b22d1b..e313cadc5 100644 --- a/elephant/test/test_schemas.py +++ b/elephant/test/test_schemas.py @@ -6,7 +6,7 @@ import elephant -from pydantic import ValidationError +from elephant.schemas.function_validator import deactivate_validation from elephant.schemas.schema_statistics import * from elephant.schemas.schema_spike_train_correlation import * @@ -32,100 +32,26 @@ def test_model_json_schema(): PydanticLvr, PydanticFanofactor, PydanticComplexityPdf, - PydanticComplexityInit, PydanticSpikeContrast, + PydanticComplexityInit, PydanticSynchrotoolInit, - PydanticSynchrotoolDeleteSynchrofacts, + PydanticSynchrotoolDeleteSynchrofacts ] for cls in model_classes: schema = cls.model_json_schema() assert isinstance(schema, dict) + """ Checking for consistent behavior between Elephant functions and Pydantic models. Tests bypass validate_with decorator if it is already implemented for that function so consistency is checked correctly """ -def call_elephant_function(elephant_fn, kwargs): - if hasattr(elephant_fn, "_is_validate_with"): - kwargs["not_validate"]=True - elephant_fn(**kwargs) - else: - elephant_fn(**kwargs) - -def assert_both_succeed_consistently(elephant_fn, model_cls, kwargs): - """Call both the Elephant function and the Pydantic model with the same kwargs. - Assert both complete without raising exceptions. - - Parameters - - elephant_fn: callable to invoke with kwargs - - model_cls: Pydantic model class to instantiate with kwargs - - kwargs: dict of keyword arguments to pass to both - """ - try: - call_elephant_function(elephant_fn, kwargs) - except Exception as e: - assert False, f"Elephant function raised an exception: {e}" - - try: - model_cls(**kwargs) - except Exception as e: - assert False, f"Pydantic model raised an exception: {e}" - -def assert_both_warn_consistently(elephant_fn, model_cls, kwargs): - """Call both the Elephant function and the Pydantic model with the same kwargs. - Assert both raise warnings. 
- - Parameters - - elephant_fn: callable to invoke with kwargs - - model_cls: Pydantic model class to instantiate with kwargs - - kwargs: dict of keyword arguments to pass to both - """ - with pytest.warns(Warning) as w1: - call_elephant_function(elephant_fn, kwargs) - with pytest.warns(Warning) as w2: - model_cls(**kwargs) - - -def assert_both_raise_consistently(elephant_fn, model_cls, kwargs, *, same_type=False, expected_exception=None): - """Call both the Elephant function and the Pydantic model with the same kwargs. - Assert both raise, and if requested assert they raise the same exception type. - - Uses pytest.raises to capture exceptions so failures are reported with pytest's - native formatting while still allowing comparison of exception objects. - - Parameters - - elephant_fn: callable to invoke with kwargs - - model_cls: Pydantic model class to instantiate with kwargs - - kwargs: dict of keyword arguments to pass to both - - same_type: if True assert the raised exception classes are identical - - expected_exception: optional exception type that both must be instances of - """ - with pytest.raises(Exception) as e1: - call_elephant_function(elephant_fn, kwargs) - with pytest.raises(Exception) as e2: - model_cls(**kwargs) - - exc1 = e1.value - exc2 = e2.value - - if expected_exception is not None: - assert isinstance(exc1, expected_exception), ( - f"Elephant raised {type(exc1)}, expected {expected_exception}") - assert isinstance(exc2, expected_exception), ( - f"Pydantic raised {type(exc2)}, expected {expected_exception}") - - if same_type: - if(type(exc1) is type(exc2)): - return - - if (isinstance(exc1, (ValueError, TypeError)) and isinstance(exc2, (ValidationError, AttributeError))): - return - - assert False, ( - f"Different exception types: Elephant={type(exc1)}, Pydantic={type(exc2)}. 
" - f"Elephant exc: {exc1}; Pydantic exc: {exc2}") +# Deactivate validation happening in the decorator of the elephant functions for all tests in this module to keep checking consistent behavior +@pytest.fixture(autouse=True) +def disable_validation_for_tests(): + deactivate_validation() @pytest.fixture def make_list(): @@ -178,7 +104,9 @@ def fixture(request): ], indirect=["fixture"]) def test_valid_spiketrain_input(elephant_fn, model_cls, fixture): valid = {"spiketrain": fixture} - assert_both_succeed_consistently(elephant_fn, model_cls, valid) + assert(isinstance(model_cls(**valid), model_cls)) + # just check it runs without error + elephant_fn(**valid) @pytest.mark.parametrize("elephant_fn,model_cls", [ @@ -191,7 +119,10 @@ def test_valid_spiketrain_input(elephant_fn, model_cls, fixture): ]) def test_invalid_spiketrain(elephant_fn, model_cls, spiketrain): invalid = {"spiketrain": spiketrain} - assert_both_raise_consistently(elephant_fn, model_cls, invalid) + with pytest.raises(Exception): + model_cls(**invalid) + with pytest.raises(Exception): + elephant_fn(**invalid) @pytest.mark.parametrize("elephant_fn,model_cls", [ @@ -200,7 +131,9 @@ def test_invalid_spiketrain(elephant_fn, model_cls, spiketrain): ]) def test_valid_pq_quantity(elephant_fn, model_cls, make_spiketrains, make_pq_single_quantity): valid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity} - assert_both_succeed_consistently(elephant_fn, model_cls, valid) + assert(isinstance(model_cls(**valid), model_cls)) + # just check it runs without error + elephant_fn(**valid) @pytest.mark.parametrize("elephant_fn,model_cls", [ @@ -213,8 +146,11 @@ def test_valid_pq_quantity(elephant_fn, model_cls, make_spiketrains, make_pq_sin [0.01, 0.02] ]) def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quantity): - valid = {"spiketrains": make_spiketrains, "bin_size": pq_quantity} - assert_both_raise_consistently(elephant_fn, model_cls, valid) + invalid = {"spiketrains": make_spiketrains, "bin_size": pq_quantity} + with pytest.raises(Exception): + model_cls(**invalid) + with pytest.raises(Exception): + elephant_fn(**invalid) @@ -228,7 +164,10 @@ def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quanti ], indirect=["fixture"]) def test_invalid_spiketrains(elephant_fn, model_cls, fixture, make_pq_single_quantity): invalid = {"spiketrains": fixture, "sampling_period": make_pq_single_quantity} - assert_both_raise_consistently(elephant_fn, model_cls, invalid) + with pytest.raises(Exception): + model_cls(**invalid) + with pytest.raises(Exception): + elephant_fn(**invalid) @pytest.mark.parametrize("output", [ "counts", @@ -237,7 +176,9 @@ def test_invalid_spiketrains(elephant_fn, model_cls, fixture, make_pq_single_qua ]) def test_valid_enum(output, make_spiketrains, make_pq_single_quantity): valid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output} - assert_both_succeed_consistently(elephant.statistics.time_histogram, PydanticTimeHistogram, valid) + assert(isinstance(PydanticTimeHistogram(**valid), PydanticTimeHistogram)) + # just check it runs without error + elephant.statistics.time_histogram(**valid) @pytest.mark.parametrize("output", [ "countsfagre", @@ -249,24 +190,24 @@ def test_valid_enum(output, make_spiketrains, make_pq_single_quantity): ]) def test_invalid_enum(output, make_spiketrains, make_pq_single_quantity): invalid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output} - 
assert_both_raise_consistently(elephant.statistics.time_histogram, PydanticTimeHistogram, invalid) + with pytest.raises(Exception): + PydanticTimeHistogram(**invalid) + with pytest.raises(Exception): + elephant.statistics.time_histogram(**invalid) def test_valid_binned_spiketrain(make_binned_spiketrain): valid = {"binned_spiketrain": make_binned_spiketrain} - assert_both_succeed_consistently( - elephant.spike_train_correlation.covariance, - PydanticCovariance, - valid - ) + assert(isinstance(PydanticCovariance(**valid), PydanticCovariance)) + # just check it runs without error + elephant.spike_train_correlation.covariance(**valid) def test_invalid_binned_spiketrain(make_spiketrain): invalid = {"binned_spiketrain": make_spiketrain} - assert_both_raise_consistently( - elephant.spike_train_correlation.covariance, - PydanticCovariance, - invalid, - ) + with pytest.raises(Exception): + PydanticCovariance(**invalid) + with pytest.raises(Exception): + elephant.spike_train_correlation.covariance(**invalid) @pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [ (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate, "spiketrains", []), @@ -275,20 +216,24 @@ def test_invalid_binned_spiketrain(make_spiketrain): ]) def test_invalid_empty_input(elephant_fn, model_cls, parameter_name, empty_input): invalid = {parameter_name: empty_input} - assert_both_raise_consistently(elephant_fn, model_cls, invalid) + with pytest.raises(Exception): + model_cls(**invalid) + with pytest.raises(Exception): + elephant_fn(**invalid) @pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [ (elephant.spike_train_correlation.covariance, PydanticCovariance, "binned_spiketrain", elephant.conversion.BinnedSpikeTrain(neo.core.SpikeTrain(np.array([])*pq.s, t_start=0*pq.s, t_stop=1*pq.s), bin_size=0.01*pq.s)), ]) def test_warning_empty_input(elephant_fn, model_cls, parameter_name, empty_input): warning = {parameter_name: empty_input} - assert_both_warn_consistently(elephant_fn, model_cls, warning) + with pytest.warns(Warning): + model_cls(**warning) + with pytest.warns(Warning): + elephant_fn(**warning) def test_valid_Complexity(make_spiketrains, make_pq_single_quantity): valid = { "spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity } - assert_both_succeed_consistently( - elephant.statistics.Complexity, - PydanticComplexityInit, - valid, - ) \ No newline at end of file + assert(isinstance(PydanticComplexityInit(**valid), PydanticComplexityInit)) + # just check it runs without error + elephant.statistics.Complexity(**valid) \ No newline at end of file From ca247c0cc3492dd79ead61fef263fe43f0993fe3 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 24 Nov 2025 16:11:43 +0100 Subject: [PATCH 26/29] Make test stricter by checking for the exact Error Type. 
Also Fixed Bugs --- elephant/schemas/function_validator.py | 3 ++ elephant/test/test_schemas.py | 56 ++++++++++++++------------ 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/elephant/schemas/function_validator.py b/elephant/schemas/function_validator.py index 9eaf163bd..4300d0aa4 100644 --- a/elephant/schemas/function_validator.py +++ b/elephant/schemas/function_validator.py @@ -15,6 +15,7 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + print(skip_validation) if not skip_validation: # Bind args & kwargs to function parameters bound = sig.bind_partial(*args, **kwargs) @@ -30,7 +31,9 @@ def wrapper(*args, **kwargs): return decorator def activate_validation(): + global skip_validation skip_validation = False def deactivate_validation(): + global skip_validation skip_validation = True \ No newline at end of file diff --git a/elephant/test/test_schemas.py b/elephant/test/test_schemas.py index e313cadc5..3db6b8162 100644 --- a/elephant/test/test_schemas.py +++ b/elephant/test/test_schemas.py @@ -6,7 +6,8 @@ import elephant -from elephant.schemas.function_validator import deactivate_validation +from pydantic import ValidationError +from elephant.schemas.function_validator import deactivate_validation, activate_validation from elephant.schemas.schema_statistics import * from elephant.schemas.schema_spike_train_correlation import * @@ -48,10 +49,15 @@ def test_model_json_schema(): so consistency is checked correctly """ -# Deactivate validation happening in the decorator of the elephant functions for all tests in this module to keep checking consistent behavior -@pytest.fixture(autouse=True) -def disable_validation_for_tests(): - deactivate_validation() +# Deactivate validation happening in the decorator of the elephant functions before all tests in this module to keep checking consistent behavior. Activates it again after all tests in this module have run. 
+ +@pytest.fixture(scope="module", autouse=True) +def module_setup_teardown(): + deactivate_validation() + + yield + + activate_validation() @pytest.fixture def make_list(): @@ -119,9 +125,9 @@ def test_valid_spiketrain_input(elephant_fn, model_cls, fixture): ]) def test_invalid_spiketrain(elephant_fn, model_cls, spiketrain): invalid = {"spiketrain": spiketrain} - with pytest.raises(Exception): + with pytest.raises(TypeError): model_cls(**invalid) - with pytest.raises(Exception): + with pytest.raises((TypeError, ValueError)): elephant_fn(**invalid) @@ -147,9 +153,9 @@ def test_valid_pq_quantity(elephant_fn, model_cls, make_spiketrains, make_pq_sin ]) def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quantity): invalid = {"spiketrains": make_spiketrains, "bin_size": pq_quantity} - with pytest.raises(Exception): + with pytest.raises(TypeError): model_cls(**invalid) - with pytest.raises(Exception): + with pytest.raises(AttributeError): elephant_fn(**invalid) @@ -164,9 +170,9 @@ def test_invalid_pq_quantity(elephant_fn, model_cls, make_spiketrains, pq_quanti ], indirect=["fixture"]) def test_invalid_spiketrains(elephant_fn, model_cls, fixture, make_pq_single_quantity): invalid = {"spiketrains": fixture, "sampling_period": make_pq_single_quantity} - with pytest.raises(Exception): + with pytest.raises(TypeError): model_cls(**invalid) - with pytest.raises(Exception): + with pytest.raises(TypeError): elephant_fn(**invalid) @pytest.mark.parametrize("output", [ @@ -190,9 +196,9 @@ def test_valid_enum(output, make_spiketrains, make_pq_single_quantity): ]) def test_invalid_enum(output, make_spiketrains, make_pq_single_quantity): invalid = {"spiketrains": make_spiketrains, "bin_size": make_pq_single_quantity, "output": output} - with pytest.raises(Exception): + with pytest.raises(ValidationError): PydanticTimeHistogram(**invalid) - with pytest.raises(Exception): + with pytest.raises(ValueError): elephant.statistics.time_histogram(**invalid) @@ -204,21 +210,21 @@ def test_valid_binned_spiketrain(make_binned_spiketrain): def test_invalid_binned_spiketrain(make_spiketrain): invalid = {"binned_spiketrain": make_spiketrain} - with pytest.raises(Exception): + with pytest.raises(TypeError): PydanticCovariance(**invalid) - with pytest.raises(Exception): + with pytest.raises(AttributeError): elephant.spike_train_correlation.covariance(**invalid) -@pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [ - (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate, "spiketrains", []), - (elephant.statistics.optimal_kernel_bandwidth, PydanticOptimalKernelBandwidth, "spiketimes", np.array([])), - (elephant.statistics.cv2, PydanticCv2, "time_intervals", np.array([])*pq.s), +@pytest.mark.parametrize("elephant_fn,model_cls,invalid", [ + (elephant.statistics.instantaneous_rate, PydanticInstantaneousRate, {"spiketrains": [], "sampling_period": 0.01 * pq.s}), + (elephant.statistics.optimal_kernel_bandwidth, PydanticOptimalKernelBandwidth, {"spiketimes": np.array([])}), + (elephant.statistics.cv2, PydanticCv2, {"time_intervals": np.array([])*pq.s}), ]) -def test_invalid_empty_input(elephant_fn, model_cls, parameter_name, empty_input): - invalid = {parameter_name: empty_input} - with pytest.raises(Exception): +def test_invalid_empty_input(elephant_fn, model_cls, invalid): + + with pytest.raises(ValueError): model_cls(**invalid) - with pytest.raises(Exception): + with pytest.raises((ValueError,TypeError)): elephant_fn(**invalid) 
@pytest.mark.parametrize("elephant_fn,model_cls,parameter_name,empty_input", [ @@ -226,9 +232,9 @@ def test_invalid_empty_input(elephant_fn, model_cls, parameter_name, empty_input ]) def test_warning_empty_input(elephant_fn, model_cls, parameter_name, empty_input): warning = {parameter_name: empty_input} - with pytest.warns(Warning): + with pytest.warns(UserWarning): model_cls(**warning) - with pytest.warns(Warning): + with pytest.warns(UserWarning): elephant_fn(**warning) From 91df7872bf87446637e2dd3846dce678cc394e4d Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 24 Nov 2025 16:20:38 +0100 Subject: [PATCH 27/29] Forgot to remove a print statement --- elephant/schemas/function_validator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elephant/schemas/function_validator.py b/elephant/schemas/function_validator.py index 4300d0aa4..e388339a4 100644 --- a/elephant/schemas/function_validator.py +++ b/elephant/schemas/function_validator.py @@ -15,7 +15,6 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - print(skip_validation) if not skip_validation: # Bind args & kwargs to function parameters bound = sig.bind_partial(*args, **kwargs) From 5dfa70724c06fbd4e67f03f7f7f2bbaccf3bd4e6 Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 5 Jan 2026 09:26:27 +0100 Subject: [PATCH 28/29] BugFix: t_start and t_stop should be able to be Nonw when the spiketrain is a neo.SpikeTrain --- elephant/schemas/schema_statistics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py index 9d79644da..a5de15b62 100644 --- a/elephant/schemas/schema_statistics.py +++ b/elephant/schemas/schema_statistics.py @@ -44,11 +44,12 @@ def validate_time(cls, v, info): @model_validator(mode="after") def validate_model(self): - if isinstance(self.spiketrain, (np.ndarray, list)): + if isinstance(self.spiketrain, (neo.SpikeTrain, pq.Quantity)): + if not ((self.t_start is None or isinstance(self.t_start, pq.Quantity)) and (self.t_stop is None or isinstance(self.t_stop, pq.Quantity))): + raise TypeError("spiketrain is a neo.SpikeTrain or pq.Quantity but t_start or t_stop is not pq.Quantity") + elif isinstance(self.spiketrain, (np.ndarray, list)): if isinstance(self.t_start, pq.Quantity) or isinstance(self.t_stop, pq.Quantity): raise TypeError("spiketrain is a np.ndarray or list but t_start or t_stop is pq.Quantity") - elif not (isinstance(self.t_start, pq.Quantity) and isinstance(self.t_stop, pq.Quantity)): - raise TypeError("spiketrain is a neo.SpikeTrain or pq.Quantity but t_start or t_stop is not pq.Quantity") return self class PydanticInstantaneousRate(BaseModel): From c85e1ea85d55de29e1c18492138cb190474b12cb Mon Sep 17 00:00:00 2001 From: Jan Nolten Date: Mon, 5 Jan 2026 09:50:39 +0100 Subject: [PATCH 29/29] BugFix: t_start and t_stop are also allowed to be floats --- elephant/schemas/schema_statistics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elephant/schemas/schema_statistics.py b/elephant/schemas/schema_statistics.py index a5de15b62..7b7030f3a 100644 --- a/elephant/schemas/schema_statistics.py +++ b/elephant/schemas/schema_statistics.py @@ -45,8 +45,8 @@ def validate_time(cls, v, info): @model_validator(mode="after") def validate_model(self): if isinstance(self.spiketrain, (neo.SpikeTrain, pq.Quantity)): - if not ((self.t_start is None or isinstance(self.t_start, pq.Quantity)) and (self.t_stop is None or isinstance(self.t_stop, pq.Quantity))): - raise 
TypeError("spiketrain is a neo.SpikeTrain or pq.Quantity but t_start or t_stop is not pq.Quantity") + if not ((self.t_start is None or isinstance(self.t_start, (pq.Quantity, float))) and (self.t_stop is None or isinstance(self.t_stop, (pq.Quantity, float)))): + raise TypeError("spiketrain is a neo.SpikeTrain or pq.Quantity but t_start or t_stop is not pq.Quantity or float") elif isinstance(self.spiketrain, (np.ndarray, list)): if isinstance(self.t_start, pq.Quantity) or isinstance(self.t_stop, pq.Quantity): raise TypeError("spiketrain is a np.ndarray or list but t_start or t_stop is pq.Quantity")
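As a closing usage note (illustrative only, not part of the patch series): the global switch added to function_validator.py lets callers bypass all argument validation, e.g. in tight loops, and restore it afterwards. The spike train values below are made up for the example.

    # Hypothetical usage sketch of the global validation switch.
    import quantities as pq
    import neo

    from elephant.statistics import mean_firing_rate
    from elephant.schemas.function_validator import (
        activate_validation,
        deactivate_validation,
    )

    st = neo.SpikeTrain([0.1, 0.5, 1.2] * pq.s, t_start=0 * pq.s, t_stop=2 * pq.s)

    # Validation is on by default: the arguments are checked against
    # PydanticMeanFiringRate before the function body runs.
    rate = mean_firing_rate(st)

    # Disable validation globally (affects all @validate_with-decorated functions).
    deactivate_validation()
    rate = mean_firing_rate(st)

    # Re-enable it afterwards.
    activate_validation()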