Skip to content

Commit ee767fb

Browse files
authored
Merge pull request #84 from oliveira-caio/tier_modality
added TimeIntervalInterpolator and tests for it.
2 parents ffafe73 + 65aa46c commit ee767fb

File tree

5 files changed

+688
-9
lines changed

5 files changed

+688
-9
lines changed

experanto/filters/common_filters.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from experanto.interpolators import Interpolator
3+
from experanto.interpolators import SequenceInterpolator
44
from experanto.intervals import (
55
TimeInterval,
66
find_complement_of_interval_array,
@@ -9,8 +9,9 @@
99

1010

1111
def nan_filter(vicinity=0.05):
12-
13-
def implementation(device_: Interpolator):
12+
def implementation(device_: SequenceInterpolator):
13+
# requests SequenceInterpolator as uses time_delta internally
14+
# and other interpolators don't have it
1415
time_delta = device_.time_delta
1516
start_time = device_.start_time
1617
end_time = device_.end_time

experanto/interpolators.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def create(root_folder: str, cache_data: bool = False, **kwargs) -> "Interpolato
5858
return SequenceInterpolator(root_folder, cache_data, **kwargs)
5959
elif modality == "screen":
6060
return ScreenInterpolator(root_folder, cache_data, **kwargs)
61+
elif modality == "time_interval":
62+
return TimeIntervalInterpolator(root_folder, cache_data, **kwargs)
6163
else:
6264
raise ValueError(
63-
f"There is no interpolator for {modality}. Please use 'sequence' or 'screen' as modality."
65+
f"There is no interpolator for {modality}. Please use 'sequence', 'screen' or 'time_interval' as modality."
6466
)
6567

6668
def valid_times(self, times: np.ndarray) -> np.ndarray:
@@ -451,7 +453,6 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
451453
if ((len(data.shape) == 2) or (data.shape[-1] == 3)) and (
452454
len(data.shape) < 4
453455
):
454-
455456
data = np.expand_dims(data, axis=0)
456457
idx_for_this_file = np.where(self._data_file_idx[idx] == u_idx)
457458
if self.rescale:
@@ -478,6 +479,97 @@ def rescale_frame(self, frame: np.array) -> np.array:
478479
)
479480

480481

482+
class TimeIntervalInterpolator(Interpolator):
483+
def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
484+
super().__init__(root_folder)
485+
self.cache_data = cache_data
486+
487+
meta = self.load_meta()
488+
self.meta_labels = meta["labels"]
489+
self.start_time = meta["start_time"]
490+
self.end_time = meta["end_time"]
491+
self.valid_interval = TimeInterval(self.start_time, self.end_time)
492+
493+
if self.cache_data:
494+
self.labeled_intervals = {
495+
label: np.load(self.root_folder / filename)
496+
for label, filename in self.meta_labels.items()
497+
}
498+
499+
def interpolate(self, times: np.ndarray) -> np.ndarray:
500+
"""
501+
Interpolate time intervals for labeled events.
502+
503+
Given a set of time points and a set of labeled intervals (defined in the
504+
`meta.yml` file), this method returns a boolean array indicating, for each
505+
time point, whether it falls within any interval for each label.
506+
507+
The method uses half-open intervals [start, end), where a timestamp t is
508+
considered to fall within an interval if start <= t < end. This means the
509+
start time is inclusive and the end time is exclusive.
510+
511+
Parameters
512+
----------
513+
times : np.ndarray
514+
Array of time points to be checked against the labeled intervals.
515+
516+
Returns
517+
-------
518+
out : np.ndarray of bool, shape (len(valid_times), n_labels)
519+
Boolean array where each row corresponds to a valid time point and each
520+
column corresponds to a label. `out[i, j]` is True if the i-th valid
521+
time falls within any interval for the j-th label, and False otherwise.
522+
523+
Notes
524+
-----
525+
- The labels and their corresponding intervals are defined in the `meta.yml`
526+
file under the `labels` key. Each label points to a `.npy` file containing
527+
an array of shape (n, 2), where each row is a [start, end) time interval.
528+
- Typical labels might include 'train', 'validation', 'test', 'saccade',
529+
'gaze', or 'target'.
530+
- Only time points within the valid interval (as defined by start_time and
531+
end_time in meta.yml) are considered; others are filtered out.
532+
- Intervals where start > end are considered invalid and will trigger a
533+
warning.
534+
"""
535+
valid = self.valid_times(times)
536+
valid_times = times[valid]
537+
538+
n_labels = len(self.meta_labels)
539+
n_times = len(valid_times)
540+
541+
if n_times == 0:
542+
warnings.warn(
543+
"TimeIntervalInterpolator returns an empty array, no valid times queried."
544+
)
545+
return np.empty((0, n_labels), dtype=bool)
546+
547+
out = np.zeros((n_times, n_labels), dtype=bool)
548+
for i, (label, filename) in enumerate(self.meta_labels.items()):
549+
if self.cache_data:
550+
intervals = self.labeled_intervals[label]
551+
else:
552+
intervals = np.load(self.root_folder / filename, allow_pickle=True)
553+
554+
if len(intervals) == 0:
555+
warnings.warn(
556+
f"TimeIntervalInterpolator found no intervals for label: {label}"
557+
)
558+
continue
559+
560+
for start, end in intervals:
561+
if start > end:
562+
warnings.warn(
563+
f"Invalid interval found for label: {label}, interval: ({start}, {end})"
564+
)
565+
continue
566+
# Half-open interval [start, end): inclusive start, exclusive end
567+
mask = (valid_times >= start) & (valid_times < end)
568+
out[mask, i] = True
569+
570+
return out
571+
572+
481573
class ScreenTrial:
482574
def __init__(
483575
self,
@@ -548,7 +640,6 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
548640

549641
class BlankTrial(ScreenTrial):
550642
def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
551-
552643
self.interleave_value = meta_data.get("interleave_value")
553644

554645
super().__init__(
@@ -567,7 +658,6 @@ def get_data_(self) -> np.array:
567658

568659
class InvalidTrial(ScreenTrial):
569660
def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
570-
571661
self.interleave_value = meta_data.get("interleave_value")
572662

573663
super().__init__(

experanto/intervals.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def intersect(self, times: np.ndarray) -> np.ndarray:
2828
return np.where((times >= self.start) & (times <= self.end))[0]
2929

3030

31-
def uniquefy_interval_array(interval_array: List[TimeInterval]) -> List[TimeInterval]:
31+
def uniquefy_interval_array(
32+
interval_array: List[TimeInterval],
33+
) -> List[TimeInterval]:
3234
"""
3335
Takes an array of TimeIntervals and returns a new array where no intervals overlap.
3436
If intervals overlap or are adjacent, they are merged into a single interval.
@@ -92,7 +94,6 @@ def find_intersection_between_two_interval_arrays(
9294
def find_intersection_across_arrays_of_intervals(
9395
intervals_array: List[List[TimeInterval]],
9496
) -> TimeInterval:
95-
9697
common_interval_array = intervals_array[0]
9798

9899
for interval_array in intervals_array[1:]:
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import shutil
2+
from contextlib import closing, contextmanager
3+
from pathlib import Path
4+
5+
import numpy as np
6+
import yaml
7+
8+
from experanto.interpolators import Interpolator
9+
10+
TIME_INTERVAL_ROOT = Path("tests/time_interval_data")
11+
12+
13+
@contextmanager
14+
def create_time_interval_data(
15+
duration=10.0,
16+
sampling_rate=30.0,
17+
test_intervals=None,
18+
train_intervals=None,
19+
validation_intervals=None,
20+
):
21+
"""
22+
Create time interval test data with non-integer timestamps.
23+
24+
Parameters
25+
----------
26+
duration : float
27+
Total duration of the recording in seconds.
28+
sampling_rate : float
29+
Sampling rate in Hz for generating timestamps.
30+
test_intervals : list of [start, end], optional
31+
List of time ranges for test label. Defaults to [[0.0, 2.0]].
32+
train_intervals : list of [start, end], optional
33+
List of time ranges for train label. Defaults to [[2.0, 4.0], [6.0, 8.0]].
34+
validation_intervals : list of [start, end], optional
35+
List of time ranges for validation label. Defaults to [[4.0, 6.0], [8.0, 10.0]].
36+
37+
Yields
38+
------
39+
timestamps : np.ndarray
40+
Array of timestamp values.
41+
intervals_dict : dict
42+
Dictionary mapping label names to their interval arrays.
43+
"""
44+
try:
45+
TIME_INTERVAL_ROOT.mkdir(parents=True, exist_ok=True)
46+
47+
# Default intervals
48+
if test_intervals is None:
49+
test_intervals = [[0.0, 2.0]]
50+
if train_intervals is None:
51+
train_intervals = [[2.0, 4.0], [6.0, 8.0]]
52+
if validation_intervals is None:
53+
validation_intervals = [[4.0, 6.0], [8.0, duration]]
54+
55+
# Generate non-integer timestamps
56+
n_samples = int(duration * sampling_rate)
57+
timestamps = np.linspace(0.0, duration, n_samples, endpoint=False)
58+
59+
# Create metadata
60+
meta = {
61+
"labels": {
62+
"test": "test.npy",
63+
"train": "train.npy",
64+
"validation": "validation.npy",
65+
},
66+
"start_time": 0.0,
67+
"end_time": duration,
68+
"modality": "time_interval",
69+
}
70+
71+
with open(TIME_INTERVAL_ROOT / "meta.yml", "w") as f:
72+
yaml.dump(meta, f)
73+
74+
# Save interval files
75+
test_array = np.array(test_intervals, dtype=np.float64)
76+
train_array = np.array(train_intervals, dtype=np.float64)
77+
validation_array = np.array(validation_intervals, dtype=np.float64)
78+
79+
np.save(TIME_INTERVAL_ROOT / "test.npy", test_array)
80+
np.save(TIME_INTERVAL_ROOT / "train.npy", train_array)
81+
np.save(TIME_INTERVAL_ROOT / "validation.npy", validation_array)
82+
83+
intervals_dict = {
84+
"test": test_array,
85+
"train": train_array,
86+
"validation": validation_array,
87+
}
88+
89+
yield timestamps, intervals_dict
90+
91+
finally:
92+
shutil.rmtree(TIME_INTERVAL_ROOT, ignore_errors=True)
93+
94+
95+
@contextmanager
96+
def time_interval_data_and_interpolator(data_kwargs=None, interp_kwargs=None):
97+
"""
98+
Create time interval test data and interpolator in one context manager.
99+
100+
This follows the pattern used in sequence_data_and_interpolator for consistency.
101+
102+
Parameters
103+
----------
104+
data_kwargs : dict, optional
105+
Keyword arguments to pass to create_time_interval_data.
106+
interp_kwargs : dict, optional
107+
Keyword arguments to pass to Interpolator.create.
108+
109+
Yields
110+
------
111+
timestamps : np.ndarray
112+
Array of timestamp values.
113+
intervals_dict : dict
114+
Dictionary mapping label names to their interval arrays.
115+
interpolator : TimeIntervalInterpolator
116+
The interpolator object.
117+
"""
118+
data_kwargs = data_kwargs or {}
119+
interp_kwargs = interp_kwargs or {}
120+
121+
with create_time_interval_data(**data_kwargs) as (
122+
timestamps,
123+
intervals_dict,
124+
):
125+
with closing(
126+
Interpolator.create("tests/time_interval_data", **interp_kwargs)
127+
) as time_interp:
128+
yield timestamps, intervals_dict, time_interp

0 commit comments

Comments
 (0)