diff --git a/elephant/gpfa/gpfa.py b/elephant/gpfa/gpfa.py index 85681172d..c9e28f076 100644 --- a/elephant/gpfa/gpfa.py +++ b/elephant/gpfa/gpfa.py @@ -271,7 +271,7 @@ def binsize(self): warnings.warn("'binsize' is deprecated; use 'bin_size'") return self.bin_size - def fit(self, spiketrains): + def fit(self, spiketrains, seqs_train=None): """ Fit the model with the given training data. @@ -287,6 +287,15 @@ def fit(self, spiketrains): `spiketrains[k][n]` refer to spike trains of the same neuron for any choices of `l`, `k`, and `n`. + seqs_train: np.recarray + Alternatively, pass a pre-processed seqs_train array. + This is a training data structure, whose n-th element (corresponding to + the n-th experimental trial) has fields + T : int + number of bins + y : (#units, T) np.ndarray + neural data + Returns ------- self : object @@ -301,8 +310,17 @@ def fit(self, spiketrains): If covariance matrix of input spike data is rank deficient. """ - self._check_training_data(spiketrains) - seqs_train = self._format_training_data(spiketrains) + + if seqs_train is not None and spiketrains is not None: + raise ValueError('Cannot provide both spiketrains and seqs_train!') + elif spiketrains is not None: + self._check_training_data(spiketrains) + seqs_train = self._format_training_data(spiketrains) + elif seqs_train is not None: + seqs_train = self._format_training_data_seqs(seqs_train) # TODO: write this! + else: + raise ValueError('Must supply either spiketrains or seqs_train!') + # Check if training data covariance is full rank y_all = np.hstack(seqs_train['y']) y_dim = y_all.shape[0] @@ -351,7 +369,14 @@ def _format_training_data(self, spiketrains): seq['y'] = seq['y'][self.has_spikes_bool, :] return seqs - def transform(self, spiketrains, returned_data=['latent_variable_orth']): + def _format_training_data_seqs(self, seqs): + # Remove inactive units based on training set + self.has_spikes_bool = np.hstack(seqs['y']).any(axis=1) + for seq in seqs: + seq['y'] = seq['y'][self.has_spikes_bool, :] + return seqs + + def transform(self, spiketrains, seqs=None, returned_data=['latent_variable_orth']): """ Obtain trajectories of neural activity in a low-dimensional latent variable space by inferring the posterior mean of the obtained GPFA @@ -424,14 +449,21 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']): If `returned_data` contains keys different from the ones in `self.valid_data_names`. """ - if len(spiketrains[0]) != len(self.has_spikes_bool): - raise ValueError("'spiketrains' must contain the same number of " - "neurons as the training spiketrain data") + invalid_keys = set(returned_data).difference(self.valid_data_names) if len(invalid_keys) > 0: raise ValueError("'returned_data' can only have the following " - "entries: {}".format(self.valid_data_names)) - seqs = gpfa_util.get_seqs(spiketrains, self.bin_size) + "entries: {}".format(self.valid_data_names)) + + if spiketrains is not None: + if len(spiketrains[0]) != len(self.has_spikes_bool): + raise ValueError("'spiketrains' must contain the same number of " + "neurons as the training spiketrain data") + seqs = gpfa_util.get_seqs(spiketrains, self.bin_size) + elif seqs is not None: + # check some stuff + pass + for seq in seqs: seq['y'] = seq['y'][self.has_spikes_bool, :] seqs, ll = gpfa_core.exact_inference_with_ll(seqs,