@@ -58,9 +58,11 @@ def create(root_folder: str, cache_data: bool = False, **kwargs) -> "Interpolato
5858 return SequenceInterpolator (root_folder , cache_data , ** kwargs )
5959 elif modality == "screen" :
6060 return ScreenInterpolator (root_folder , cache_data , ** kwargs )
61+ elif modality == "time_interval" :
62+ return TimeIntervalInterpolator (root_folder , cache_data , ** kwargs )
6163 else :
6264 raise ValueError (
63- f"There is no interpolator for { modality } . Please use 'sequence' or 'screen ' as modality."
65+ f"There is no interpolator for { modality } . Please use 'sequence', 'screen' or 'time_interval ' as modality."
6466 )
6567
6668 def valid_times (self , times : np .ndarray ) -> np .ndarray :
@@ -451,7 +453,6 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
451453 if ((len (data .shape ) == 2 ) or (data .shape [- 1 ] == 3 )) and (
452454 len (data .shape ) < 4
453455 ):
454-
455456 data = np .expand_dims (data , axis = 0 )
456457 idx_for_this_file = np .where (self ._data_file_idx [idx ] == u_idx )
457458 if self .rescale :
@@ -478,6 +479,97 @@ def rescale_frame(self, frame: np.array) -> np.array:
478479 )
479480
480481
482+ class TimeIntervalInterpolator (Interpolator ):
483+ def __init__ (self , root_folder : str , cache_data : bool = False , ** kwargs ):
484+ super ().__init__ (root_folder )
485+ self .cache_data = cache_data
486+
487+ meta = self .load_meta ()
488+ self .meta_labels = meta ["labels" ]
489+ self .start_time = meta ["start_time" ]
490+ self .end_time = meta ["end_time" ]
491+ self .valid_interval = TimeInterval (self .start_time , self .end_time )
492+
493+ if self .cache_data :
494+ self .labeled_intervals = {
495+ label : np .load (self .root_folder / filename )
496+ for label , filename in self .meta_labels .items ()
497+ }
498+
499+ def interpolate (self , times : np .ndarray ) -> np .ndarray :
500+ """
501+ Interpolate time intervals for labeled events.
502+
503+ Given a set of time points and a set of labeled intervals (defined in the
504+ `meta.yml` file), this method returns a boolean array indicating, for each
505+ time point, whether it falls within any interval for each label.
506+
507+ The method uses half-open intervals [start, end), where a timestamp t is
508+ considered to fall within an interval if start <= t < end. This means the
509+ start time is inclusive and the end time is exclusive.
510+
511+ Parameters
512+ ----------
513+ times : np.ndarray
514+ Array of time points to be checked against the labeled intervals.
515+
516+ Returns
517+ -------
518+ out : np.ndarray of bool, shape (len(valid_times), n_labels)
519+ Boolean array where each row corresponds to a valid time point and each
520+ column corresponds to a label. `out[i, j]` is True if the i-th valid
521+ time falls within any interval for the j-th label, and False otherwise.
522+
523+ Notes
524+ -----
525+ - The labels and their corresponding intervals are defined in the `meta.yml`
526+ file under the `labels` key. Each label points to a `.npy` file containing
527+ an array of shape (n, 2), where each row is a [start, end) time interval.
528+ - Typical labels might include 'train', 'validation', 'test', 'saccade',
529+ 'gaze', or 'target'.
530+ - Only time points within the valid interval (as defined by start_time and
531+ end_time in meta.yml) are considered; others are filtered out.
532+ - Intervals where start > end are considered invalid and will trigger a
533+ warning.
534+ """
535+ valid = self .valid_times (times )
536+ valid_times = times [valid ]
537+
538+ n_labels = len (self .meta_labels )
539+ n_times = len (valid_times )
540+
541+ if n_times == 0 :
542+ warnings .warn (
543+ "TimeIntervalInterpolator returns an empty array, no valid times queried."
544+ )
545+ return np .empty ((0 , n_labels ), dtype = bool )
546+
547+ out = np .zeros ((n_times , n_labels ), dtype = bool )
548+ for i , (label , filename ) in enumerate (self .meta_labels .items ()):
549+ if self .cache_data :
550+ intervals = self .labeled_intervals [label ]
551+ else :
552+ intervals = np .load (self .root_folder / filename , allow_pickle = True )
553+
554+ if len (intervals ) == 0 :
555+ warnings .warn (
556+ f"TimeIntervalInterpolator found no intervals for label: { label } "
557+ )
558+ continue
559+
560+ for start , end in intervals :
561+ if start > end :
562+ warnings .warn (
563+ f"Invalid interval found for label: { label } , interval: ({ start } , { end } )"
564+ )
565+ continue
566+ # Half-open interval [start, end): inclusive start, exclusive end
567+ mask = (valid_times >= start ) & (valid_times < end )
568+ out [mask , i ] = True
569+
570+ return out
571+
572+
481573class ScreenTrial :
482574 def __init__ (
483575 self ,
@@ -548,7 +640,6 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
548640
549641class BlankTrial (ScreenTrial ):
550642 def __init__ (self , data_file_name , meta_data , cache_data : bool = False ) -> None :
551-
552643 self .interleave_value = meta_data .get ("interleave_value" )
553644
554645 super ().__init__ (
@@ -567,7 +658,6 @@ def get_data_(self) -> np.array:
567658
568659class InvalidTrial (ScreenTrial ):
569660 def __init__ (self , data_file_name , meta_data , cache_data : bool = False ) -> None :
570-
571661 self .interleave_value = meta_data .get ("interleave_value" )
572662
573663 super ().__init__ (
0 commit comments