Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions elephant/gpfa/gpfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def binsize(self):
warnings.warn("'binsize' is deprecated; use 'bin_size'")
return self.bin_size

def fit(self, spiketrains):
def fit(self, spiketrains, seqs_train=None):
"""
Fit the model with the given training data.

Expand All @@ -287,6 +287,15 @@ def fit(self, spiketrains):
`spiketrains[k][n]` refer to spike trains of the same neuron
for any choices of `l`, `k`, and `n`.

seqs_train: np.recarray
Alternatively, pass a pre-processed seqs_train array.
This is a training data structure, whose n-th element (corresponding to
the n-th experimental trial) has fields
T : int
number of bins
y : (#units, T) np.ndarray
neural data

Returns
-------
self : object
Expand All @@ -301,8 +310,17 @@ def fit(self, spiketrains):

If covariance matrix of input spike data is rank deficient.
"""
self._check_training_data(spiketrains)
seqs_train = self._format_training_data(spiketrains)

if seqs_train is not None and spiketrains is not None:
raise ValueError('Cannot provide both spiketrains and seqs_train!')
elif spiketrains is not None:
self._check_training_data(spiketrains)
seqs_train = self._format_training_data(spiketrains)
elif seqs_train is not None:
seqs_train = self._format_training_data_seqs(seqs_train) # TODO: write this!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
seqs_train = self._format_training_data_seqs(seqs_train) # TODO: write this!
seqs_train = self._format_training_data_seqs(seqs_train)

Remove the TODO, since `_format_training_data_seqs` is already implemented — or does the TODO refer to something else?

else:
raise ValueError('Must supply either spiketrains or seqs_train!')

# Check if training data covariance is full rank
y_all = np.hstack(seqs_train['y'])
y_dim = y_all.shape[0]
Expand Down Expand Up @@ -351,7 +369,14 @@ def _format_training_data(self, spiketrains):
seq['y'] = seq['y'][self.has_spikes_bool, :]
return seqs

def transform(self, spiketrains, returned_data=['latent_variable_orth']):
def _format_training_data_seqs(self, seqs):
# Remove inactive units based on training set
self.has_spikes_bool = np.hstack(seqs['y']).any(axis=1)
for seq in seqs:
seq['y'] = seq['y'][self.has_spikes_bool, :]
return seqs

def transform(self, spiketrains, seqs=None, returned_data=['latent_variable_orth']):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since seqs is now a parameter of transform, consider adding a description to the docstring of transform.

"""
Obtain trajectories of neural activity in a low-dimensional latent
variable space by inferring the posterior mean of the obtained GPFA
Expand Down Expand Up @@ -424,14 +449,21 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']):
If `returned_data` contains keys different from the ones in
`self.valid_data_names`.
"""
if len(spiketrains[0]) != len(self.has_spikes_bool):
raise ValueError("'spiketrains' must contain the same number of "
"neurons as the training spiketrain data")

invalid_keys = set(returned_data).difference(self.valid_data_names)
if len(invalid_keys) > 0:
raise ValueError("'returned_data' can only have the following "
"entries: {}".format(self.valid_data_names))
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
"entries: {}".format(self.valid_data_names))

if spiketrains is not None:
if len(spiketrains[0]) != len(self.has_spikes_bool):
raise ValueError("'spiketrains' must contain the same number of "
"neurons as the training spiketrain data")
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
elif seqs is not None:
# check some stuff
Copy link
Member

@Moritz-Alexander-Kern Moritz-Alexander-Kern Aug 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# check some stuff
if len(seqs['y'][0]) != len(self.has_spikes_bool):
raise ValueError(
"'seq_trains' must contain the same number of neurons as "
"the training spiketrain data")

Thanks for your suggestion. I took the liberty of adding it here; I hope I captured the spirit of your idea correctly. #507 (comment)

pass

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pass

No longer needed; see above.


for seq in seqs:
seq['y'] = seq['y'][self.has_spikes_bool, :]
seqs, ll = gpfa_core.exact_inference_with_ll(seqs,
Expand Down