diff --git a/biolearn/model.py b/biolearn/model.py index 9f60cb4..e4edf3a 100644 --- a/biolearn/model.py +++ b/biolearn/model.py @@ -1492,6 +1492,7 @@ def no_transform(_): def predict(self, geo_data): matrix_data = self._get_data_matrix(geo_data) matrix_data = self.preprocess(matrix_data) + self._validate_required_features(matrix_data) matrix_data.loc["intercept"] = 1 # Join the coefficients and dnam_data on the index @@ -1513,14 +1514,42 @@ def predict(self, geo_data): # Return as a DataFrame return result.apply(self.transform).to_frame(name="Predicted") + def _validate_required_features(self, matrix_data): + return + def _get_data_matrix(self, geo_data): raise NotImplementedError() class LinearMethylationModel(LinearModel): + _MISSING_CPG_PREVIEW_LIMIT = 5 + def _get_data_matrix(self, geo_data): return geo_data.dnam + def _validate_required_features(self, matrix_data): + required_cpgs = self.methylation_sites() + missing_cpgs = sorted(set(required_cpgs) - set(matrix_data.index)) + if not missing_cpgs: + return + + model_name = self.details.get("name") + model_label = f" for model '{model_name}'" if model_name else "" + preview_limit = self._MISSING_CPG_PREVIEW_LIMIT + preview = ", ".join(missing_cpgs[:preview_limit]) + remaining = len(missing_cpgs) - preview_limit + if remaining > 0: + preview = ( + f"showing first {preview_limit}: {preview} (+{remaining} more)" + ) + + raise ValueError( + "Missing required CpG sites" + f"{model_label} ({len(missing_cpgs)}/{len(required_cpgs)}): " + f"{preview}. " + "Provide methylation data with these CpGs or use an imputation method that includes them." + ) + def methylation_sites(self): unique_vars = set(self.coefficients.index) - {"intercept"} return list(unique_vars) diff --git a/biolearn/test/test_model.py b/biolearn/test/test_model.py index 0fbf3db..df4fbb5 100644 --- a/biolearn/test/test_model.py +++ b/biolearn/test/test_model.py @@ -14,6 +14,8 @@ TOLERANCES = defaultdict(lambda: 1e-5) # AltumAge: TF→Torch port can differ ~1e-4 across platforms/versions. TOLERANCES["AltumAge"] = 2e-4 +# MiAge: iterative algorithm can drift ~3e-5 across Python/platform versions. +TOLERANCES["MiAge"] = 5e-5 sample_inputs = load_test_data_file("testset/testset_methylation_part0.csv") @@ -50,6 +52,18 @@ def test_models(model_name, model_entry): # Instantiate the model test_model = model_class.from_definition(model_entry) + if model_type == "LinearMethylationModel": + required_cpgs = test_model.methylation_sites() + missing_cpgs = sorted(set(required_cpgs) - set(test_data.dnam.index)) + if missing_cpgs: + missing_preview = missing_cpgs[0] + with pytest.raises( + ValueError, + match=rf"Missing required CpG sites.*{missing_preview}", + ): + test_model.predict(test_data) + return + actual_results = test_model.predict(test_data).sort_index() # Load the expected results