diff --git a/PyCytoData/data.py b/PyCytoData/data.py index 6eacf9e..ec0c14e 100644 --- a/PyCytoData/data.py +++ b/PyCytoData/data.py @@ -417,6 +417,8 @@ def subset(self, channels: Optional[ArrayLike]=None, sample: Optional[ArrayLike] self.channels = self.channels[channel_filter_condition] if self.lineage_channels is not None: self.lineage_channels = self.lineage_channels[np.isin(self.lineage_channels, self.channels)] + else: + self._lineage_channels_indices = np.arange(self.n_channels) self.sample_index = self.sample_index[filter_condition] self.cell_types = self.cell_types[filter_condition] @@ -786,6 +788,7 @@ def lineage_channels(self, lineage_channels: ArrayLike): if not np.all(np.isin(lineage_channels, self._channels)): raise ValueError("Some lineage channels are not listed in channel names.") self._lineage_channels: Optional[np.ndarray] = lineage_channels if lineage_channels is None else np.array(lineage_channels).flatten() + self._lineage_channels_indices = np.where(np.isin(self.lineage_channels, self.channels)) @property diff --git a/tests/test_data.py b/tests/test_data.py index 2764822..3fdb5e5 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -359,7 +359,21 @@ def test_subset_channels(self): assert exprs.n_channels == 2 assert exprs.lineage_channels is not None assert exprs.lineage_channels.shape[0] == 2 + assert np.all(np.equal(exprs._lineage_channels_indices, np.array([0, 1]))) assert not np.isin("Channel0", exprs.lineage_channels) + + + def test_subset_channels_lineage_indices(self): + exprs_matrix: np.ndarray = np.random.rand(100, 10) + channels: np.ndarray = np.arange(10).astype(str) + exprs = PyCytoData(exprs_matrix, channels=channels) + exprs.subset(channels=["1", "2"]) + assert exprs.n_cells == 100 + assert exprs.n_channels == 2 + assert exprs.lineage_channels is None + assert exprs._lineage_channels_indices.shape[0] == 2 + assert np.all(np.equal(exprs._lineage_channels_indices, np.array([0, 1]))) + assert not np.isin("0", exprs.lineage_channels) def test_subset_cell_types(self):