diff --git a/TestRule/__init__.py b/TestRule/__init__.py
index 0ccac2d0d..d0f88256d 100644
--- a/TestRule/__init__.py
+++ b/TestRule/__init__.py
@@ -6,6 +6,14 @@
 from cdisc_rules_engine.services.cdisc_library_service import CDISCLibraryService
 from cdisc_rules_engine.services.cache.cache_populator_service import CachePopulator
 from scripts.run_validation import run_single_rule_validation
+from cdisc_rules_engine.exceptions.custom_exceptions import (
+    CTPackageNotFoundError,
+    LibraryMetadataNotFoundError,
+)
+from scripts.script_utils import library_metadata_not_found_message
+from cdisc_library_client.custom_exceptions import (
+    ResourceNotFoundException as LibraryResourceNotFoundException,
+)
 import json
 import os
 import asyncio
@@ -17,11 +25,13 @@
 class BadRequestError(Exception):
     pass
 
 
+_REQUIRED_DATASET_KEYS = {"filename", "label", "domain", "records", "variables"}
+
+
 def validate_datasets_payload(datasets):
-    required_keys = {"filename", "label", "domain", "records", "variables"}
     missing_keys = set()
     for dataset in datasets:
-        for key in required_keys:
+        for key in _REQUIRED_DATASET_KEYS:
             if key not in dataset:
                 missing_keys.add(key)
@@ -32,19 +42,39 @@
             )
 
     if missing_keys:
-        raise KeyError(
-            f"one or more datasets missing the following keys {missing_keys}"
-        )
+        raise BadRequestError(
+            "Test data is incorrect and missing required formatting. "
+            f"Missing keys: {sorted(missing_keys)}."
+        )
 
 
 def handle_exception(e: Exception):
-    if isinstance(e, KeyError):
+    if isinstance(e, BadRequestError):
+        return func.HttpResponse(
+            json.dumps({"error": "BadRequestError", "message": str(e)}),
+            status_code=400,
+        )
+    if isinstance(e, LibraryMetadataNotFoundError):
+        msg = getattr(e, "message", None) or getattr(e, "description", None) or str(e)
+        return func.HttpResponse(
+            json.dumps(
+                {
+                    "error": "LibraryMetadataNotFoundError",
+                    "message": msg,
+                }
+            ),
+            status_code=400,
+        )
+    if isinstance(e, CTPackageNotFoundError):
+        msg = getattr(e, "message", None) or getattr(e, "description", None) or str(e)
         return func.HttpResponse(
-            json.dumps({"error": "KeyError", "message": str(e)}), status_code=400
+            json.dumps({"error": "CTPackageNotFoundError", "message": msg}),
+            status_code=400,
         )
-    elif isinstance(e, BadRequestError):
+    if isinstance(e, KeyError):
+        msg = str(e)
+        if "rule" in msg.lower() or "datasets" in msg.lower():
+            msg = f"{msg} Ensure the request body includes the required JSON keys."
         return func.HttpResponse(
-            json.dumps({"error": "BadRequestError", "message": str(e)}), status_code=400
+            json.dumps({"error": "BadRequestError", "message": msg}),
+            status_code=400,
         )
     else:
         return func.HttpResponse(
@@ -97,12 +127,25 @@ def main(req: func.HttpRequest, context: func.Context) -> func.HttpResponse:
         # asyncio.run(cache_populator.load_available_ct_packages())
         if standards_data or codelists:
             if standards_data:
-                asyncio.run(
-                    cache_populator.load_standard(
-                        standard, standard_version, standard_substandard
-                    )
-                )
-            asyncio.run(cache_populator.load_codelists(codelists))
+                try:
+                    asyncio.run(
+                        cache_populator.load_standard(
+                            standard, standard_version, standard_substandard
+                        )
+                    )
+                except LibraryResourceNotFoundException:
+                    raise LibraryMetadataNotFoundError(
+                        library_metadata_not_found_message(
+                            standard, standard_version, standard_substandard
+                        )
+                    )
+            try:
+                asyncio.run(cache_populator.load_codelists(codelists or []))
+            except LibraryResourceNotFoundException:
+                raise CTPackageNotFoundError(
+                    "Controlled terminology package(s) not found: "
+                    f"{', '.join(str(c) for c in (codelists or []))}."
+                )
         if not rule:
             raise KeyError("'rule' required in request")
         datasets = json_data.get("datasets")
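Note: with this change a malformed payload surfaces as a structured 400 response instead of an unhandled KeyError. A minimal sketch of the body produced by handle_exception above for a dataset missing its label (illustrative only; the exact key list depends on the request):

    # Status: 400
    # {
    #     "error": "BadRequestError",
    #     "message": "Test data is incorrect and missing required formatting. Missing keys: ['label']."
    # }
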
diff --git a/cdisc_rules_engine/enums/excel_test_sheets.py b/cdisc_rules_engine/enums/excel_test_sheets.py
new file mode 100644
index 000000000..4e6d90581
--- /dev/null
+++ b/cdisc_rules_engine/enums/excel_test_sheets.py
@@ -0,0 +1,8 @@
+from cdisc_rules_engine.enums.base_enum import BaseEnum
+
+
+class ExcelDataSheets(BaseEnum):
+    DATASETS_SHEET_NAME = "Datasets"
+    DATASET_FILENAME_COLUMN = "Filename"
+    DATASET_LABEL_COLUMN = "Label"
+    DATASETS_SHEET_REQUIRED_COLUMNS = ("Filename", "Label")
diff --git a/cdisc_rules_engine/enums/standard_types.py b/cdisc_rules_engine/enums/standard_types.py
new file mode 100644
index 000000000..f787d4afd
--- /dev/null
+++ b/cdisc_rules_engine/enums/standard_types.py
@@ -0,0 +1,11 @@
+from cdisc_rules_engine.enums.base_enum import BaseEnum
+
+
+class StandardTypes(BaseEnum):
+    """Standards supported by CDISC Library; used for CLI validation when not using --custom-standard."""
+
+    SDTMIG = "sdtmig"
+    SENDIG = "sendig"
+    ADAM = "adam"
+    TIG = "tig"
+    USDM = "usdm"
diff --git a/cdisc_rules_engine/exceptions/custom_exceptions.py b/cdisc_rules_engine/exceptions/custom_exceptions.py
index e3e5b79d5..669cb719c 100644
--- a/cdisc_rules_engine/exceptions/custom_exceptions.py
+++ b/cdisc_rules_engine/exceptions/custom_exceptions.py
@@ -40,6 +40,13 @@ class VariableMetadataNotFoundError(EngineError):
     )
 
 
+class LibraryMetadataNotFoundError(EngineError):
+    code = 400
+    description = (
+        "Library metadata not found for the provided standard and version combination."
+    )
+
+
 class DomainNotFoundError(EngineError):
     """Raised when a required domain is not found in the dataset"""
 
@@ -62,6 +69,19 @@ class InvalidJSONFormat(EngineError):
     description = "JSON data is malformed."
 
 
+class ExcelTestDataError(EngineError):
+    code = 400
+    description = (
+        "Excel test data file is missing required sheets or column headers. "
+        "Sheet and column names are case-sensitive."
+    )
+
+
+class CTPackageNotFoundError(EngineError):
+    code = 400
+    description = "Controlled terminology package(s) not found"
+
+
 class NumberOfAttemptsExceeded(EngineError):
     pass
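Note: both new exceptions follow the existing EngineError pattern of a class-level code and description, with the instance message supplied at raise time. A minimal usage sketch (the package name is hypothetical; it assumes EngineError surfaces the message via str(), which is the last step of handle_exception's getattr fallback chain):

    from cdisc_rules_engine.exceptions.custom_exceptions import CTPackageNotFoundError

    try:
        # "sdtmct-2099-01-01" is a made-up package name for illustration
        raise CTPackageNotFoundError("Controlled terminology package(s) not found: sdtmct-2099-01-01.")
    except CTPackageNotFoundError as exc:
        # handle_exception reads exc.message, then exc.description, then str(exc)
        print(str(exc))
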
diff --git a/cdisc_rules_engine/services/data_services/excel_data_service.py b/cdisc_rules_engine/services/data_services/excel_data_service.py
index dbfa60567..9127aeb89 100644
--- a/cdisc_rules_engine/services/data_services/excel_data_service.py
+++ b/cdisc_rules_engine/services/data_services/excel_data_service.py
@@ -1,5 +1,6 @@
 import os
 from io import IOBase
+import functools
 from typing import List, Sequence
 from datetime import datetime
 import re
@@ -14,16 +15,14 @@
 from cdisc_rules_engine.models.variable_metadata_container import (
     VariableMetadataContainer,
 )
-from cdisc_rules_engine.services import logger
+from cdisc_rules_engine.exceptions.custom_exceptions import ExcelTestDataError
 from cdisc_rules_engine.services.data_readers.data_reader_factory import (
     DataReaderFactory,
 )
 from .base_data_service import BaseDataService, cached_dataset
-
-DATASETS_SHEET_NAME = "Datasets"
-DATASET_FILENAME_COLUMN = "Filename"
-DATASET_LABEL_COLUMN = "Label"
-DATASET_NAME_COLUMN = "Dataset Name"
+from cdisc_rules_engine.enums.excel_test_sheets import (
+    ExcelDataSheets,
+)
 
 
 class ExcelDataService(BaseDataService):
@@ -112,34 +111,43 @@ def get_dataset(self, dataset_name: str, **params) -> DatasetInterface:
     def _get_dataset_name(
         self, metadata: pd.DataFrame, first_record: dict, dataset_filename: str
     ) -> str:
-        if DATASET_NAME_COLUMN in metadata.columns and not metadata.empty:
-            return metadata[DATASET_NAME_COLUMN].iloc[0]
         if self.standard == "usdm":
             return first_record.get("instanceType", dataset_filename.split(".")[0])
         return dataset_filename.split(".")[0].upper()
 
+    @functools.lru_cache(maxsize=None)
+    def _get_datasets_worksheet(self) -> pd.DataFrame:
+        return pd.read_excel(
+            self.dataset_path,
+            sheet_name=ExcelDataSheets.DATASETS_SHEET_NAME.value,
+            na_values=[""],
+            keep_default_na=False,
+        )
+
     @cached_dataset(DatasetTypes.RAW_METADATA.value)
     def get_raw_dataset_metadata(
-        self, dataset_name: str, **kwargs
+        self,
+        dataset_name: str,
+        **kwargs,
     ) -> SDTMDatasetMetadata:
         """
         Returns dataset metadata as DatasetMetadata instance.
         """
-        datasets_worksheet = pd.read_excel(
-            self.dataset_path,
-            sheet_name=DATASETS_SHEET_NAME,
-            na_values=[""],
-            keep_default_na=False,
-        )
+        datasets_worksheet = self._get_datasets_worksheet()
         metadata = datasets_worksheet[
-            datasets_worksheet[DATASET_FILENAME_COLUMN] == dataset_name
+            datasets_worksheet[ExcelDataSheets.DATASET_FILENAME_COLUMN.value]
+            == dataset_name
         ]
         dataset = self.get_dataset(dataset_name=dataset_name)
         first_record = dataset.data.iloc[0].to_dict() if not dataset.empty else {}
         return SDTMDatasetMetadata(
             name=self._get_dataset_name(metadata, first_record, dataset_name),
             first_record=first_record,
-            label=metadata[DATASET_LABEL_COLUMN].iloc[0] if not metadata.empty else "",
+            label=(
+                metadata[ExcelDataSheets.DATASET_LABEL_COLUMN.value].iloc[0]
+                if not metadata.empty
+                else ""
+            ),
             modification_date=datetime.fromtimestamp(
                 os.path.getmtime(self.dataset_path)
             ).isoformat(),
@@ -199,23 +207,41 @@ def read_data(self, file_path: str) -> IOBase:
 
     def get_datasets(self) -> List[dict]:
         try:
-            worksheet = pd.read_excel(
-                self.dataset_path,
-                sheet_name=DATASETS_SHEET_NAME,
-                na_values=[""],
-                keep_default_na=False,
-            )
-        except TypeError as e:
-            logger.error(
-                f"Failed to read datasets from the Excel file at {self.dataset_path}. "
-                f"Ensure the file is in the correct format. "
-                f"Try opening and saving the file in Microsoft Excel. "
-                f"Error: {str(e)}"
-            )
+            with pd.ExcelFile(self.dataset_path) as xl:
+                sheet_names = xl.sheet_names
+                if ExcelDataSheets.DATASETS_SHEET_NAME.value not in sheet_names:
+                    available = ", ".join(repr(s) for s in sheet_names) or "(none)"
+                    raise ExcelTestDataError(
+                        f"The workbook does not contain a "
+                        f"'{ExcelDataSheets.DATASETS_SHEET_NAME.value}' sheet. "
+                        f"Sheets found in the workbook: {available}."
+                    )
+                worksheet = xl.parse(
+                    ExcelDataSheets.DATASETS_SHEET_NAME.value,
+                    na_values=[""],
+                    keep_default_na=False,
+                )
+        except ExcelTestDataError:
             raise
+        except Exception as e:
+            raise ExcelTestDataError(
+                f"Cannot read the Excel file. Ensure it is a valid .xlsx workbook. "
+                f"Details: {e}"
+            ) from e
+
+        missing_cols = sorted(
+            set(ExcelDataSheets.DATASETS_SHEET_REQUIRED_COLUMNS.value)
+            - set(worksheet.columns)
+        )
+        if missing_cols:
+            raise ExcelTestDataError(
+                f"The '{ExcelDataSheets.DATASETS_SHEET_NAME.value}' sheet is missing "
+                f"required column(s): {missing_cols}. Required columns: "
+                f"{ExcelDataSheets.DATASETS_SHEET_REQUIRED_COLUMNS.value}. "
+                f"Column headers are case-sensitive."
+            )
+
         datasets = [
-            self.get_raw_dataset_metadata(dataset_name=dataset_filename)
-            for dataset_filename in worksheet[DATASET_FILENAME_COLUMN]
+            self.get_raw_dataset_metadata(dataset_name=fn)
+            for fn in worksheet[ExcelDataSheets.DATASET_FILENAME_COLUMN.value]
         ]
         return datasets
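Note: for reference, a minimal workbook that passes the new checks in get_datasets (the sheet name 'Datasets' and the 'Filename'/'Label' headers are case-sensitive; the 'Dataset Name' column is no longer read), mirroring the unit tests below:

    from openpyxl import Workbook

    wb = Workbook()
    ws = wb.active
    ws.title = "Datasets"             # must match ExcelDataSheets.DATASETS_SHEET_NAME exactly
    ws.append(["Filename", "Label"])  # required, case-sensitive headers
    ws.append(["dm.xpt", "Demographics"])
    wb.create_sheet("dm.xpt")         # one sheet per dataset, named by its filename
    wb.save("test_data.xlsx")         # hypothetical output path
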
" - f"Ensure the file is in the correct format. " - f"Try opening and saving the file in Microsoft Excel. " - f"Error: {str(e)}" - ) + with pd.ExcelFile(self.dataset_path) as xl: + sheet_names = xl.sheet_names + if ExcelDataSheets.DATASETS_SHEET_NAME.value not in sheet_names: + available = ", ".join(repr(s) for s in sheet_names) or "(none)" + raise ExcelTestDataError( + f"The workbook does not contain a '{ExcelDataSheets.DATASETS_SHEET_NAME.value}' sheet. " + f"Submitted sheet names: {available}." + ) + worksheet = xl.parse( + ExcelDataSheets.DATASETS_SHEET_NAME.value, + na_values=[""], + keep_default_na=False, + ) + except ExcelTestDataError: raise + except Exception as e: + raise ExcelTestDataError( + f"Cannot read the Excel file. Ensure it is a valid .xlsx workbook. " + f"Details: {e}" + ) from e + + missing_cols = sorted( + set(ExcelDataSheets.DATASETS_SHEET_REQUIRED_COLUMNS.value) + - set(worksheet.columns) + ) + if missing_cols: + raise ExcelTestDataError( + f"The '{ExcelDataSheets.DATASETS_SHEET_NAME.value}' sheet is missing a " + f"required {ExcelDataSheets.DATASETS_SHEET_REQUIRED_COLUMNS.value} column(s): " + f"{missing_cols}. Column headers are case-sensitive. " + ) + datasets = [ - self.get_raw_dataset_metadata(dataset_name=dataset_filename) - for dataset_filename in worksheet[DATASET_FILENAME_COLUMN] + self.get_raw_dataset_metadata(dataset_name=fn) + for fn in worksheet[ExcelDataSheets.DATASET_FILENAME_COLUMN.value] ] return datasets diff --git a/cdisc_rules_engine/services/data_services/local_data_service.py b/cdisc_rules_engine/services/data_services/local_data_service.py index cffb61bd3..6f2408a51 100644 --- a/cdisc_rules_engine/services/data_services/local_data_service.py +++ b/cdisc_rules_engine/services/data_services/local_data_service.py @@ -24,11 +24,11 @@ convert_file_size, extract_file_name_from_path_string, ) +from cdisc_rules_engine.exceptions.custom_exceptions import InvalidDatasetFormat from .base_data_service import BaseDataService, cached_dataset from cdisc_rules_engine.enums.dataformat_types import DataFormatTypes from cdisc_rules_engine.models.dataset.dataset_interface import DatasetInterface from cdisc_rules_engine.models.dataset import PandasDataset -from cdisc_rules_engine.services import logger import re @@ -244,28 +244,12 @@ def get_datasets(self) -> List[dict]: dataset_name=dataset_path ) datasets.append(dataset_metadata) + except InvalidDatasetFormat: + raise except Exception as e: - logger.error( - f"Failed to read metadata for dataset {dataset_path}. " - f"Error: {type(e).__name__}: {e}. Skipping this dataset." - ) - file_name = extract_file_name_from_path_string(dataset_path) - datasets.append( - SDTMDatasetMetadata( - name=( - file_name.split(".")[0].upper() - if "." in file_name - else file_name.upper() - ), - first_record={}, - label="", - modification_date="", - filename=file_name, - full_path=dataset_path, - file_size=0, - record_count=0, - ) - ) + raise InvalidDatasetFormat( + f"Your data file could not be read: {dataset_path}." 
diff --git a/core.py b/core.py
index 5db814132..2ef8d21ef 100644
--- a/core.py
+++ b/core.py
@@ -23,6 +23,7 @@
 from cdisc_rules_engine.enums.default_file_paths import DefaultFilePaths
 from cdisc_rules_engine.enums.progress_parameter_options import ProgressParameterOptions
 from cdisc_rules_engine.enums.report_types import ReportTypes
+from cdisc_rules_engine.enums.standard_types import StandardTypes
 from cdisc_rules_engine.models.external_dictionaries_container import (
     DictionaryTypes,
     ExternalDictionariesContainer,
 )
@@ -478,6 +479,15 @@ def validate(  # noqa
 
     if not custom_standard:
         standard = standard.lower()
+        supported_standards = StandardTypes.values()
+        if standard not in supported_standards:
+            supported_list = ", ".join(sorted(supported_standards))
+            logger.error(
+                f"Standard '{standard}' is not a supported standard. "
+                f"Supported standards: {supported_list}. "
+                f"Use the --custom-standard flag for custom standards."
+            )
+            ctx.exit(2)
 
     if raw_report is True:
         if not (len(output_format) == 1 and output_format[0] == ReportTypes.JSON.value):
diff --git a/scripts/run_validation.py b/scripts/run_validation.py
index fe3573b22..3b38028d3 100644
--- a/scripts/run_validation.py
+++ b/scripts/run_validation.py
@@ -18,6 +18,9 @@
 from cdisc_rules_engine.models.sdtm_dataset_metadata import SDTMDatasetMetadata
 from cdisc_rules_engine.models.validation_args import Validation_args
 from cdisc_rules_engine.rules_engine import RulesEngine
+from cdisc_rules_engine.exceptions.custom_exceptions import (
+    LibraryMetadataNotFoundError,
+)
 from cdisc_rules_engine.services import logger as engine_logger
 from cdisc_rules_engine.services.cache import (
     InMemoryCacheService,
@@ -40,11 +43,12 @@
     set_max_errors_per_rule,
 )
 from scripts.script_utils import (
+    library_metadata_not_found_message,
     fill_cache_with_dictionaries,
     get_cache_service,
     get_library_metadata_from_cache,
-    get_rules,
     get_max_dataset_size,
+    get_rules,
 )
 from cdisc_rules_engine.services.reporting import BaseReport, ReportFactory
 from cdisc_rules_engine.utilities.progress_displayers import get_progress_displayer
@@ -123,6 +127,25 @@ def initialize_logger(disabled, log_level):
         engine_logger.setLevel(log_level)
 
 
+def _convert_datasets_to_parquet_if_needed(
+    data_service, datasets, created_files, large_dataset_validation: bool
+):
+    if not (large_dataset_validation and data_service.standard != "usdm"):
+        return
+    engine_logger.warning(
+        "Large datasets must use parquet format, converting all datasets to parquet"
+    )
+    for dataset in datasets:
+        file_path = dataset.full_path
+        if file_path.endswith(".parquet"):
+            continue
+        num_rows, new_file = data_service.to_parquet(file_path)
+        created_files.append(new_file)
+        dataset.full_path = new_file
+        dataset.record_count = num_rows
+        dataset.original_path = file_path
+
+
 def run_validation(args: Validation_args):
     set_log_level(args)
     # fill cache
@@ -161,20 +184,12 @@ def run_validation(args: Validation_args):
             data_service.dataset_implementation != PandasDataset
         )
         datasets = data_service.get_datasets()
-        if large_dataset_validation and data_service.standard != "usdm":
-            # convert all files to parquet temp files
-            engine_logger.warning(
-                "Large datasets must use parquet format, converting all datasets to parquet"
-            )
-            for dataset in datasets:
-                file_path = dataset.full_path
-                if file_path.endswith(".parquet"):
-                    continue
-                num_rows, new_file = data_service.to_parquet(file_path)
-                created_files.append(new_file)
-                dataset.full_path = new_file
-                dataset.record_count = num_rows
-                dataset.original_path = file_path
+        _convert_datasets_to_parquet_if_needed(
+            data_service,
+            datasets,
+            created_files,
+            large_dataset_validation,
+        )
         engine_logger.info(
             f"Running {len(rules)} rules against {len(datasets)} datasets"
         )
@@ -249,6 +264,12 @@ def run_single_rule_validation(
         standard, standard_version, standard_substandard
     )
     standard_metadata = cache.get(standard_details_cache_key)
+    if not standard_metadata and standard and standard_version:
+        raise LibraryMetadataNotFoundError(
+            library_metadata_not_found_message(
+                standard, standard_version, standard_substandard
+            )
+        )
     if standard_metadata:
         model_cache_key = get_model_details_cache_key_from_ig(standard_metadata)
         model_metadata = cache.get(model_cache_key)
@@ -259,6 +280,7 @@ def run_single_rule_validation(
         )
 
     ct_package_metadata = {}
+    codelists = codelists or []
     for codelist in codelists:
         ct_package_metadata[codelist] = cache.get(codelist)
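Note: with this check an unsupported standard is now rejected up front with exit code 2 before any validation work starts. A hedged CLI sketch ('-s'/'-v'/'-d' are the validate command's usual short flags; check core.py for the authoritative option names):

    # Supported standard -> proceeds to validation
    python core.py validate -s sdtmig -v 3-4 -d ./data
    # Unsupported standard without --custom-standard -> logs the supported list, exits with code 2
    python core.py validate -s mystandard -v 1-0 -d ./data
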
diff --git a/scripts/script_utils.py b/scripts/script_utils.py
index b08035d3a..fc5f1d4bd 100644
--- a/scripts/script_utils.py
+++ b/scripts/script_utils.py
@@ -27,6 +27,10 @@
 from cdisc_rules_engine.services.define_xml.define_xml_reader_factory import (
     DefineXMLReaderFactory,
 )
+from cdisc_rules_engine.exceptions.custom_exceptions import (
+    CTPackageNotFoundError,
+    LibraryMetadataNotFoundError,
+)
 
 
 def get_library_metadata_from_cache(args) -> LibraryMetadataContainer:  # noqa
@@ -75,6 +79,14 @@ def get_library_metadata_from_cache(args) -> LibraryMetadataContainer:  # noqa
             data = pickle.load(f)
             standard_metadata = data.get(standard_details_cache_key, {})
 
+    if not standard_metadata and not args.custom_standard:
+        if args.standard and args.standard.lower() != "usdm":
+            raise LibraryMetadataNotFoundError(
+                library_metadata_not_found_message(
+                    args.standard, args.version, args.substandard
+                )
+            )
+
     if standard_metadata:
         model_cache_key = get_model_details_cache_key_from_ig(standard_metadata)
         with open(models_file, "rb") as f:
@@ -104,6 +116,7 @@ def get_library_metadata_from_cache(args) -> LibraryMetadataContainer:  # noqa
             variables_metadata = data.get(cache_key)
 
     ct_package_data = {}
+    define_referenced_ct = set()
     cache_files = next(os.walk(args.cache), (None, None, []))[2]
     ct_files = [file_name for file_name in cache_files if "ct-" in file_name]
     published_ct_packages = set()
@@ -124,6 +137,10 @@ def get_library_metadata_from_cache(args) -> LibraryMetadataContainer:  # noqa
                 extensible,
                 merged_flag,
             ) = define_xml_reader.get_ct_standards_metadata()
+            define_referenced_ct = {
+                f"{standard.publishing_set.lower()}ct-{standard.version}"
+                for standard in standards
+            }
             for standard in standards:
                 pickle_filename = (
                     f"{standard.publishing_set.lower()}ct-{standard.version}.pkl"
@@ -141,6 +158,16 @@ def get_library_metadata_from_cache(args) -> LibraryMetadataContainer:  # noqa
     if args.define_xml_path:
         extensible_terms = define_xml_reader.get_extensible_codelist_mappings()
         ct_package_data["extensible"] = extensible_terms
+    requested_ct = set(args.controlled_terminology_package or []) | define_referenced_ct
+    missing_ct = requested_ct - published_ct_packages
+    if missing_ct:
+        sorted_missing = sorted(
+            missing_ct, key=lambda x: (x is None, str(x) if x is not None else "")
+        )
+        raise CTPackageNotFoundError(
+            "Controlled terminology package(s) not found in cache: "
+            f"{', '.join(str(c) for c in sorted_missing)}."
+        )
     return LibraryMetadataContainer(
         standard_metadata=standard_metadata,
         standard_schema_definition=standard_schema_definition,
@@ -585,3 +612,12 @@ def replace_yml_spaces(data):
         return [replace_yml_spaces(item) for item in data]
     else:
         return data
+
+
+def library_metadata_not_found_message(standard, version, substandard=None):
+    version_display = (version or "").replace("-", ".")
+    sub_part = f" substandard {substandard}" if substandard else ""
+    return (
+        f"No library metadata found for standard '{standard}' "
+        f"version '{version_display}'{sub_part}."
+    )
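Note: the new helper converts the cache-style version ('3-4') to its display form ('3.4'). Expected output, per the implementation above:

    >>> library_metadata_not_found_message("sdtmig", "3-4")
    "No library metadata found for standard 'sdtmig' version '3.4'."
    >>> library_metadata_not_found_message("tig", "1-0", substandard="SDTM")
    "No library metadata found for standard 'tig' version '1.0' substandard SDTM."
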
diff --git a/tests/unit/test_datasets_payload_validation.py b/tests/unit/test_datasets_payload_validation.py
new file mode 100644
index 000000000..95e3bf626
--- /dev/null
+++ b/tests/unit/test_datasets_payload_validation.py
@@ -0,0 +1,168 @@
+"""
+Unit tests for datasets payload validation and API error handling.
+Covers the TestRule Azure function: validate_datasets_payload and handle_exception.
+"""
+
+import json
+import importlib
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+from cdisc_rules_engine.exceptions.custom_exceptions import (
+    CTPackageNotFoundError,
+    LibraryMetadataNotFoundError,
+)
+from scripts.script_utils import library_metadata_not_found_message
+
+
+class _MockHttpResponse:
+    def __init__(self, body, status_code=200):
+        self.status_code = status_code
+        self._body = body if isinstance(body, bytes) else body.encode("utf-8")
+
+    def get_body(self):
+        return self._body
+
+
+_mock_func = MagicMock()
+_mock_func.HttpResponse = _MockHttpResponse
+_mock_azure = MagicMock()
+_mock_azure.functions = _mock_func
+sys.modules["azure"] = _mock_azure
+sys.modules["azure.functions"] = _mock_func
+
+
+def _get_testrule_module():
+    return importlib.import_module("TestRule")
+
+
+class TestValidateDatasetsPayload:
+    """Test that validate_datasets_payload raises clear, actionable errors."""
+
+    def test_missing_required_properties_raises_bad_request_with_datasets_guidance(
+        self,
+    ):
+        testrule = _get_testrule_module()
+        datasets = [
+            {
+                "filename": "dm.xpt",
+                "domain": "DM",
+                "records": {"USUBJID": ["1"]},
+                "variables": [{"name": "USUBJID"}],
+            }
+        ]
+        with pytest.raises(testrule.BadRequestError) as exc_info:
+            testrule.validate_datasets_payload(datasets)
+        msg = str(exc_info.value)
+        assert "Test data is incorrect and missing required formatting" in msg
+
+    def test_missing_multiple_required_properties_raises_with_datasets_guidance(self):
+        testrule = _get_testrule_module()
+        datasets = [
+            {
+                "filename": "dm.xpt",
+            }
+        ]
+        with pytest.raises(testrule.BadRequestError) as exc_info:
+            testrule.validate_datasets_payload(datasets)
+        msg = str(exc_info.value)
+        assert "Test data is incorrect and missing required formatting" in msg
+
+    def test_valid_payload_passes(self):
+        testrule = _get_testrule_module()
+        datasets = [
+            {
+                "filename": "dm.xpt",
+                "label": "Demographics",
+                "domain": "DM",
+                "records": {"USUBJID": ["1"]},
+                "variables": [{"name": "USUBJID"}],
+            }
+        ]
+        testrule.validate_datasets_payload(datasets)
+
+    def test_missing_variable_metadata_raises_bad_request(self):
+        testrule = _get_testrule_module()
+        datasets = [
+            {
+                "filename": "dm.xpt",
+                "label": "Demographics",
+                "domain": "DM",
+                "records": {"USUBJID": ["1"]},
+                "variables": [None],
+            }
+        ]
+        with pytest.raises(testrule.BadRequestError) as exc_info:
+            testrule.validate_datasets_payload(datasets)
+        assert "variable metadata" in str(exc_info.value)
+
+
+class TestHandleException:
+    """Test that handle_exception returns user-friendly JSON for clients."""
for clients.""" + + def test_bad_request_error_returns_400_with_message(self): + testrule = _get_testrule_module() + e = testrule.BadRequestError( + "Test data is incorrect and missing required formatting." + ) + response = testrule.handle_exception(e) + assert response.status_code == 400 + body = json.loads(response.get_body().decode()) + assert body["error"] == "BadRequestError" + assert "message" in body + assert ( + "Test data is incorrect and missing required formatting" in body["message"] + ) + + def test_key_error_for_rule_returns_400_with_bad_request_error_type(self): + testrule = _get_testrule_module() + e = KeyError("'rule' required in request") + response = testrule.handle_exception(e) + assert response.status_code == 400 + body = json.loads(response.get_body().decode()) + assert body["error"] == "BadRequestError" + assert ( + "rule" in body["message"].lower() or "required" in body["message"].lower() + ) + + def test_key_error_for_datasets_returns_400_with_bad_request_error_type(self): + testrule = _get_testrule_module() + e = KeyError("'datasets' required in request") + response = testrule.handle_exception(e) + assert response.status_code == 400 + body = json.loads(response.get_body().decode()) + assert body["error"] == "BadRequestError" + + def test_library_metadata_not_found_error_returns_400_with_message(self): + testrule = _get_testrule_module() + e = LibraryMetadataNotFoundError( + library_metadata_not_found_message("sdtmig", "3-4") + ) + response = testrule.handle_exception(e) + assert response.status_code == 400 + body = json.loads(response.get_body().decode()) + assert body["error"] == "LibraryMetadataNotFoundError" + assert "sdtmig" in body["message"] + assert "3.4" in body["message"] or "version" in body["message"] + + def test_ct_package_not_found_error_returns_400_with_message(self): + testrule = _get_testrule_module() + e = CTPackageNotFoundError( + "Controlled terminology package(s) not found: bad-ct-pkg." 
+
+    def test_ct_package_not_found_error_returns_400_with_message(self):
+        testrule = _get_testrule_module()
+        e = CTPackageNotFoundError(
+            "Controlled terminology package(s) not found: bad-ct-pkg."
+        )
+        response = testrule.handle_exception(e)
+        assert response.status_code == 400
+        body = json.loads(response.get_body().decode())
+        assert body["error"] == "CTPackageNotFoundError"
+        assert "not found" in body["message"]
+        assert "bad-ct-pkg" in body["message"]
+
+    def test_other_exception_returns_400_unknown_exception(self):
+        testrule = _get_testrule_module()
+        e = ValueError("Something else went wrong")
+        response = testrule.handle_exception(e)
+        assert response.status_code == 400
+        body = json.loads(response.get_body().decode())
+        assert body["error"] == "Unknown Exception"
+        assert "Something else went wrong" in body["message"]
diff --git a/tests/unit/test_services/test_data_service/test_excel_data_service.py b/tests/unit/test_services/test_data_service/test_excel_data_service.py
index 6e2e4c90f..a75ca30d5 100644
--- a/tests/unit/test_services/test_data_service/test_excel_data_service.py
+++ b/tests/unit/test_services/test_data_service/test_excel_data_service.py
@@ -7,7 +7,11 @@
 from openpyxl import Workbook
 
 from cdisc_rules_engine.config.config import ConfigService
+from cdisc_rules_engine.exceptions.custom_exceptions import ExcelTestDataError
 from cdisc_rules_engine.services.data_services import ExcelDataService
+from cdisc_rules_engine.enums.excel_test_sheets import (
+    ExcelDataSheets,
+)
 from cdisc_rules_engine.models.dataset import PandasDataset
 
 
@@ -175,3 +179,87 @@ def test_na_value_preserved_not_converted_to_nan():
     finally:
         # Cleanup temporary file
         os.unlink(temp_path)
+
+
+def test_get_datasets_missing_datasets_sheet_raises_friendly_error():
+    """
+    When the workbook has no 'Datasets' sheet (e.g. a tab named 'datasets' instead),
+    get_datasets() raises ExcelTestDataError with a message that names the expected
+    sheet and lists the sheets that were found.
+    """
+    with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as tmp_file:
+        temp_path = tmp_file.name
+
+    try:
+        wb = Workbook()
+        wb.active.title = "datasets"
+        wb.active.append(["Filename", "Label", "Dataset Name"])
+        wb.active.append(["dm.xpt", "Demographics", "DM"])
+        wb.create_sheet("dm.xpt")
+        dm_sheet = wb["dm.xpt"]
+        dm_sheet.append(["USUBJID", "DOMAIN"])
+        dm_sheet.append(["Study ID", "Domain"])
+        dm_sheet.append(["Char", "Char"])
+        dm_sheet.append(["20", "2"])
+        dm_sheet.append(["SUBJ001", "DM"])
+        wb.save(temp_path)
+        wb.close()
+
+        ExcelDataService._instance = None
+        mock_cache = MagicMock()
+        mock_cache.get_dataset.return_value = None
+
+        data_service = ExcelDataService(
+            mock_cache, MagicMock(), MagicMock(), dataset_path=temp_path
+        )
+
+        with pytest.raises(ExcelTestDataError) as exc_info:
+            data_service.get_datasets()
+
+        msg = str(exc_info.value)
+        assert ExcelDataSheets.DATASETS_SHEET_NAME.value in msg
+    finally:
+        os.unlink(temp_path)
+ """ + with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as tmp_file: + temp_path = tmp_file.name + + try: + wb = Workbook() + datasets_sheet = wb.active + datasets_sheet.title = ExcelDataSheets.DATASETS_SHEET_NAME.value + datasets_sheet.append(["Filename", "label", "Dataset Name"]) + datasets_sheet.append(["dm.xpt", "Demographics", "DM"]) + wb.create_sheet("dm.xpt") + dm_sheet = wb["dm.xpt"] + dm_sheet.append(["USUBJID", "DOMAIN"]) + dm_sheet.append(["Study ID", "Domain"]) + dm_sheet.append(["Char", "Char"]) + dm_sheet.append(["20", "2"]) + dm_sheet.append(["SUBJ001", "DM"]) + wb.save(temp_path) + wb.close() + + ExcelDataService._instance = None + mock_cache = MagicMock() + mock_cache.get_dataset.return_value = None + + data_service = ExcelDataService( + mock_cache, MagicMock(), MagicMock(), dataset_path=temp_path + ) + + with pytest.raises(ExcelTestDataError) as exc_info: + data_service.get_datasets() + + msg = str(exc_info.value) + assert "Label" in msg + assert "column" in msg.lower() + finally: + os.unlink(temp_path) diff --git a/tests/unit/test_services/test_data_service/test_local_data_service.py b/tests/unit/test_services/test_data_service/test_local_data_service.py index 82b30744c..4dcfe8eed 100644 --- a/tests/unit/test_services/test_data_service/test_local_data_service.py +++ b/tests/unit/test_services/test_data_service/test_local_data_service.py @@ -3,6 +3,7 @@ import pytest from cdisc_rules_engine.config.config import ConfigService +from cdisc_rules_engine.exceptions.custom_exceptions import InvalidDatasetFormat from cdisc_rules_engine.services.data_services import LocalDataService from cdisc_rules_engine.models.dataset import PandasDataset @@ -85,3 +86,16 @@ def test_get_variables_metdata(dataset_implementation): ] for key in expected_keys: assert key in data + + +def test_get_datasets_raises_invalid_dataset_format_when_file_cannot_be_read(): + """get_datasets() raises InvalidDatasetFormat with user-friendly message when a file cannot be read.""" + mock_cache = MagicMock() + mock_cache.get_dataset.return_value = None + data_service = LocalDataService( + mock_cache, MagicMock(), MagicMock(), dataset_paths=["/bad/path.xpt"] + ) + with pytest.raises(InvalidDatasetFormat) as exc_info: + data_service.get_datasets() + assert "Your data file could not be read" in str(exc_info.value) + assert "/bad/path.xpt" in str(exc_info.value)