diff --git a/cdisc_rules_engine/check_operators/dataframe_operators.py b/cdisc_rules_engine/check_operators/dataframe_operators.py index 9faea2be4..355562c8a 100644 --- a/cdisc_rules_engine/check_operators/dataframe_operators.py +++ b/cdisc_rules_engine/check_operators/dataframe_operators.py @@ -1,5 +1,5 @@ from business_rules.operators import BaseType, type_operator -from typing import Union, Any, List, Tuple +from typing import Union, Any, List, Tuple, Sequence from business_rules.fields import FIELD_DATAFRAME from cdisc_rules_engine.check_operators.helpers import ( flatten_list, @@ -130,6 +130,21 @@ def replace_all_prefixes(self, values: List[str]) -> List[str]: values[i] = self.replace_prefix(values[i]) return values + def _normalize_grouping_columns( + self, within: Union[str, Sequence[str]] + ) -> List[str]: + if within is None: + raise ValueError("within parameter is required") + if isinstance(within, (list, tuple)): + columns = [self.replace_prefix(column) for column in within] + else: + columns = [self.replace_prefix(within)] + if not columns or any( + not isinstance(column, str) or not column for column in columns + ): + raise ValueError("within must contain valid column names") + return list(dict.fromkeys(columns)) + def get_comparator_data(self, comparator, value_is_literal: bool = False): if value_is_literal: return comparator @@ -1614,40 +1629,46 @@ def target_is_sorted_by(self, other_value: dict): Checking the sort order based on comparators, including date overlap checks """ target: str = self.replace_prefix(other_value.get("target")) - within: str = self.replace_prefix(other_value.get("within")) + within_columns = self._normalize_grouping_columns(other_value.get("within")) columns = other_value["comparator"] result = pd.Series([True] * len(self.value), index=self.value.index) - pandas = isinstance(self.value, PandasDataset) + is_pandas_dataset = isinstance(self.value, PandasDataset) for col in columns: comparator: str = 
self.replace_prefix(col["name"]) ascending: bool = col["sort_order"].lower() != "desc" na_pos: str = col["null_position"] - sorted_df = self.value[[target, within, comparator]].sort_values( - by=[within, comparator], ascending=ascending, na_position=na_pos + selected_columns = list( + dict.fromkeys([target, comparator, *within_columns]) ) - grouped_df = sorted_df.groupby(within) - - # Check basic sort order, remove multiindex from series + sorted_df = self.value[selected_columns].sort_values( + by=[*within_columns, comparator], + ascending=ascending, + na_position=na_pos, + ) + grouped_df = sorted_df.groupby(within_columns) basic_sort_check = grouped_df.apply( lambda x: self.check_basic_sort_order(x, target, comparator, ascending) ) - if pandas: - basic_sort_check = basic_sort_check.reset_index(level=0, drop=True) + if is_pandas_dataset and isinstance(basic_sort_check.index, pd.MultiIndex): + basic_sort_check = basic_sort_check.droplevel( + list(range(len(within_columns))) + ) else: basic_sort_check = basic_sort_check.reset_index(drop=True) - result = result & basic_sort_check - # Check date overlaps, remove multiindex from series date_overlap_check = grouped_df.apply( lambda x: self.check_date_overlaps(x, target, comparator) ) - if pandas: - date_overlap_check = date_overlap_check.reset_index(level=0, drop=True) + if is_pandas_dataset and isinstance( + date_overlap_check.index, pd.MultiIndex + ): + date_overlap_check = date_overlap_check.droplevel( + list(range(len(within_columns))) + ) else: date_overlap_check = date_overlap_check.reset_index(drop=True) - result = result & date_overlap_check + result = result & basic_sort_check & date_overlap_check - # handle edge case where a dataframe is returned if isinstance(result, (pd.DataFrame, dd.DataFrame)): if isinstance(result, dd.DataFrame): result = result.compute() diff --git a/resources/schema/Operator.json b/resources/schema/Operator.json index 8c0844690..1efd47518 100644 --- a/resources/schema/Operator.json +++ 
b/resources/schema/Operator.json @@ -591,7 +591,16 @@ "value_is_reference": { "type": "boolean" }, "type_insensitive": { "type": "boolean" }, "round_values": { "type": "boolean" }, - "within": { "$ref": "CORE-base.json#/$defs/VariableName" }, + "within": { + "oneOf": [ + { "$ref": "CORE-base.json#/$defs/VariableName" }, + { + "items": { "$ref": "CORE-base.json#/$defs/VariableName" }, + "minItems": 1, + "type": "array" + } + ] + }, "regex": { "type": "string" } }, "required": ["operator"], diff --git a/resources/schema/Operator.md b/resources/schema/Operator.md index 3810a2841..eecf2ebd6 100644 --- a/resources/schema/Operator.md +++ b/resources/schema/Operator.md @@ -1017,13 +1017,15 @@ Complement of `is_ordered_by` ### target_is_sorted_by -True if the values in `name` are ordered according to the values specified by `value` grouped by the values in `within`. Each `value` requires a variable `name`, ordering specified by `order`, and the null position specified by `null_position`. +True if the values in `name` are ordered according to the values specified by `value` grouped by the values in `within`. Each `value` entry requires a variable `name`, the ordering specified by `sort_order`, and the null position specified by `null_position`. `within` accepts either a single column or an ordered list of columns; when a list is given, rows are grouped by the combination of all listed columns. 
```yaml Check: all: - name: --SEQ - within: USUBJID + within: + - USUBJID + - MIDSTYPE operator: target_is_sorted_by value: - name: --STDTC diff --git a/tests/unit/test_check_operators/test_relationship_integrity_checks.py b/tests/unit/test_check_operators/test_relationship_integrity_checks.py index 53e3326a5..943c93b54 100644 --- a/tests/unit/test_check_operators/test_relationship_integrity_checks.py +++ b/tests/unit/test_check_operators/test_relationship_integrity_checks.py @@ -652,6 +652,70 @@ def test_target_is_sorted_by(dataset_class): ) ) + +@pytest.mark.parametrize("dataset_class", [PandasDataset, DaskDataset]) +def test_target_is_sorted_by_multiple_within(dataset_class): + usubjid = ["CDISC001", "CDISC001", "CDISC001", "CDISC001", "CDISC002", "CDISC002"] + midstype = ["A", "A", "B", "B", "A", "A"] + mids = ["A1", "A2", "B1", "B2", "A1", "A2"] + smstdtc = [ + "2006-06-01", + "2006-06-02", + "2006-06-03", + "2006-06-04", + "2007-01-01", + "2007-01-02", + ] + data = { + "USUBJID": usubjid, + "MIDSTYPE": midstype, + "MIDS": mids, + "SMSTDTC": smstdtc, + } + df = dataset_class.from_dict(data) + other_value = { + "target": "MIDS", + "within": ["USUBJID", "MIDSTYPE"], + "comparator": [ + {"name": "SMSTDTC", "sort_order": "ASC", "null_position": "last"} + ], + } + expected = [True] * len(usubjid) + result = DataframeType({"value": df}).target_is_sorted_by(other_value) + assert result.equals(df.convert_to_series(expected)) + + +@pytest.mark.parametrize("dataset_class", [PandasDataset, DaskDataset]) +def test_target_is_sorted_by_multiple_within_not_sorted(dataset_class): + usubjid = ["CDISC001", "CDISC001", "CDISC001", "CDISC001", "CDISC002", "CDISC002"] + midstype = ["A", "A", "B", "B", "A", "A"] + mids = ["A2", "A1", "B1", "B2", "A1", "A2"] + smstdtc = [ + "2006-06-01", + "2006-06-02", + "2006-06-03", + "2006-06-04", + "2007-01-01", + "2007-01-02", + ] + data = { + "USUBJID": usubjid, + "MIDSTYPE": midstype, + "MIDS": mids, + "SMSTDTC": smstdtc, + } + df = 
dataset_class.from_dict(data) + other_value = { + "target": "MIDS", + "within": ["USUBJID", "MIDSTYPE"], + "comparator": [ + {"name": "SMSTDTC", "sort_order": "ASC", "null_position": "last"} + ], + } + expected = [False, False, True, True, True, True] + result = DataframeType({"value": df}).target_is_sorted_by(other_value) + assert result.equals(df.convert_to_series(expected)) + valid_desc_df = dataset_class.from_dict( { "USUBJID": ["CDISC001", "CDISC002", "CDISC002", "CDISC001", "CDISC001"],