diff --git a/whobpyt/models/jansen_rit/jansen_rit.py b/whobpyt/models/jansen_rit/jansen_rit.py index 67029554..c847a9d5 100644 --- a/whobpyt/models/jansen_rit/jansen_rit.py +++ b/whobpyt/models/jansen_rit/jansen_rit.py @@ -442,22 +442,24 @@ def forward(self, external, hx, hE): # For each sample point, run the model by solving the differential # equations for a defined number of integration steps, # and keep only the final activity state within this set of steps - for step_i in range(self.steps_per_TR): - - # Collect the delayed inputs: - # i) index the history of E - Ed = pttranspose(hE.clone().gather(1,self.delays), 0, 1) + + + # i) index the history of E + Ed = pttranspose(hE.clone().gather(1,self.delays), 0, 1) - # ii) multiply the past states by the connectivity weights matrix, and sum over rows - LEd_p2e = ptsum(w_n_f * Ed, 1) - LEd_p2i = -ptsum(w_n_b * Ed, 1) - LEd_p2p = ptsum(w_n_l * Ed, 1) + # ii) multiply the past states by the connectivity weights matrix, and sum over rows + LEd_p2e = ptsum(w_n_f * Ed, 1) + LEd_p2i = -ptsum(w_n_b * Ed, 1) + LEd_p2p = ptsum(w_n_l * Ed, 1) + + # iii) reshape for next step + LEd_p2e = ptreshape(LEd_p2e, (n_nodes, 1)) + LEd_p2i = ptreshape(LEd_p2i, (n_nodes, 1)) + LEd_p2p = ptreshape(LEd_p2p, (n_nodes, 1)) + for step_i in range(self.steps_per_TR): - # iii) reshape for next step - LEd_p2e = ptreshape(LEd_p2e, (n_nodes, 1)) - LEd_p2i = ptreshape(LEd_p2i, (n_nodes, 1)) - LEd_p2p = ptreshape(LEd_p2p, (n_nodes, 1)) + # Collect the delayed inputs: # iv) if specified, add the laplacian component (self-connections from diagonals) if self.use_laplacian: