-
Notifications
You must be signed in to change notification settings - Fork 96
Allow passing spike rates directly to gpfa #507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -271,7 +271,7 @@ def binsize(self): | |||||||||||
| warnings.warn("'binsize' is deprecated; use 'bin_size'") | ||||||||||||
| return self.bin_size | ||||||||||||
|
|
||||||||||||
| def fit(self, spiketrains): | ||||||||||||
| def fit(self, spiketrains, seqs_train=None): | ||||||||||||
| """ | ||||||||||||
| Fit the model with the given training data. | ||||||||||||
|
|
||||||||||||
|
|
@@ -287,6 +287,15 @@ def fit(self, spiketrains): | |||||||||||
| `spiketrains[k][n]` refer to spike trains of the same neuron | ||||||||||||
| for any choices of `l`, `k`, and `n`. | ||||||||||||
|
|
||||||||||||
| seqs_train: np.recarray | ||||||||||||
| Alternatively, pass a pre-processed seqs_train array. | ||||||||||||
| This is a training data structure, whose n-th element (corresponding to | ||||||||||||
| the n-th experimental trial) has fields | ||||||||||||
| T : int | ||||||||||||
| number of bins | ||||||||||||
| y : (#units, T) np.ndarray | ||||||||||||
| neural data | ||||||||||||
|
|
||||||||||||
| Returns | ||||||||||||
| ------- | ||||||||||||
| self : object | ||||||||||||
|
|
@@ -301,8 +310,17 @@ def fit(self, spiketrains): | |||||||||||
|
|
||||||||||||
| If covariance matrix of input spike data is rank deficient. | ||||||||||||
| """ | ||||||||||||
| self._check_training_data(spiketrains) | ||||||||||||
| seqs_train = self._format_training_data(spiketrains) | ||||||||||||
|
|
||||||||||||
| if seqs_train is not None and spiketrains is not None: | ||||||||||||
| raise ValueError('Cannot provide both spiketrains and seqs_train!') | ||||||||||||
| elif spiketrains is not None: | ||||||||||||
| self._check_training_data(spiketrains) | ||||||||||||
| seqs_train = self._format_training_data(spiketrains) | ||||||||||||
| elif seqs_train is not None: | ||||||||||||
| seqs_train = self._format_training_data_seqs(seqs_train) # TODO: write this! | ||||||||||||
| else: | ||||||||||||
| raise ValueError('Must supply either spiketrains or seqs_train!') | ||||||||||||
|
|
||||||||||||
| # Check if training data covariance is full rank | ||||||||||||
| y_all = np.hstack(seqs_train['y']) | ||||||||||||
| y_dim = y_all.shape[0] | ||||||||||||
|
|
@@ -351,7 +369,14 @@ def _format_training_data(self, spiketrains): | |||||||||||
| seq['y'] = seq['y'][self.has_spikes_bool, :] | ||||||||||||
| return seqs | ||||||||||||
|
|
||||||||||||
| def transform(self, spiketrains, returned_data=['latent_variable_orth']): | ||||||||||||
| def _format_training_data_seqs(self, seqs): | ||||||||||||
| # Remove inactive units based on training set | ||||||||||||
| self.has_spikes_bool = np.hstack(seqs['y']).any(axis=1) | ||||||||||||
| for seq in seqs: | ||||||||||||
| seq['y'] = seq['y'][self.has_spikes_bool, :] | ||||||||||||
| return seqs | ||||||||||||
|
|
||||||||||||
| def transform(self, spiketrains, seqs=None, returned_data=['latent_variable_orth']): | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since |
||||||||||||
| """ | ||||||||||||
| Obtain trajectories of neural activity in a low-dimensional latent | ||||||||||||
| variable space by inferring the posterior mean of the obtained GPFA | ||||||||||||
|
|
@@ -424,14 +449,21 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']): | |||||||||||
| If `returned_data` contains keys different from the ones in | ||||||||||||
| `self.valid_data_names`. | ||||||||||||
| """ | ||||||||||||
| if len(spiketrains[0]) != len(self.has_spikes_bool): | ||||||||||||
| raise ValueError("'spiketrains' must contain the same number of " | ||||||||||||
| "neurons as the training spiketrain data") | ||||||||||||
|
|
||||||||||||
| invalid_keys = set(returned_data).difference(self.valid_data_names) | ||||||||||||
| if len(invalid_keys) > 0: | ||||||||||||
| raise ValueError("'returned_data' can only have the following " | ||||||||||||
| "entries: {}".format(self.valid_data_names)) | ||||||||||||
| seqs = gpfa_util.get_seqs(spiketrains, self.bin_size) | ||||||||||||
| "entries: {}".format(self.valid_data_names)) | ||||||||||||
|
|
||||||||||||
| if spiketrains is not None: | ||||||||||||
| if len(spiketrains[0]) != len(self.has_spikes_bool): | ||||||||||||
| raise ValueError("'spiketrains' must contain the same number of " | ||||||||||||
| "neurons as the training spiketrain data") | ||||||||||||
| seqs = gpfa_util.get_seqs(spiketrains, self.bin_size) | ||||||||||||
| elif seqs is not None: | ||||||||||||
| # check some stuff | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Thanks for your suggestion, I took the liberty to add it here, I hope I got the spirit of your idea correctly? #507 (comment) |
||||||||||||
| pass | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
no longer needed, see above |
||||||||||||
|
|
||||||||||||
| for seq in seqs: | ||||||||||||
| seq['y'] = seq['y'][self.has_spikes_bool, :] | ||||||||||||
| seqs, ll = gpfa_core.exact_inference_with_ll(seqs, | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove TODO, since
_format_training_data_seqsis implemented, or is this referring to something different?