diff --git a/bnd/pipeline/pyaldata.py b/bnd/pipeline/pyaldata.py index 7f271b3..b0ee052 100644 --- a/bnd/pipeline/pyaldata.py +++ b/bnd/pipeline/pyaldata.py @@ -736,9 +736,6 @@ def expand_dim_in_single_bin_trials(self, column_subset="_spikes") -> None: column_subset : String expression to look for in columns to be expanded. Defaults to 'spikes_' - Returns - ------- - """ def _expand_dim_in_single_bin_trial(value): @@ -752,13 +749,24 @@ def _expand_dim_in_single_bin_trial(value): return + def drop_empty_states_at_end(self) -> None: + """ + Drop last column if all spike fields are empty + + """ + spike_cols = [col for col in self.pyaldata_df.columns if col.endswith("_spikes")] + final_state = self.pyaldata_df.iloc[-1] + all_zero = all(final_state[spike_col].size == 0 for spike_col in spike_cols) + + if all_zero: + self.pyaldata_df.drop(self.pyaldata_df.index[-1], inplace=True) + + return + def run_conversion(self): """ Main routine for pyaldata conversion - Returns - ------- - """ # Define all the necessary columns @@ -794,6 +802,9 @@ def run_conversion(self): # Expand dimensions self.expand_dim_in_single_bin_trials() + # Drop empty states at the end + self.drop_empty_states_at_end() + logger.info("Session converted to pyaldata format") return