diff --git a/cdisc_rules_engine/check_operators/dataframe_operators.py b/cdisc_rules_engine/check_operators/dataframe_operators.py index 1cf37ede4..d90bb8a55 100644 --- a/cdisc_rules_engine/check_operators/dataframe_operators.py +++ b/cdisc_rules_engine/check_operators/dataframe_operators.py @@ -82,6 +82,8 @@ def __init__(self, data): self.codelist_term_maps = data.get("codelist_term_maps", []) def _assert_valid_value_and_cast(self, value): + if isinstance(value, dict): + value = self._resolve_prefixes(value) return value def _regex_str_conversion(self, x): @@ -141,15 +143,21 @@ def replace_all_prefixes(self, values: List[str]) -> List[str]: values[i] = self.replace_prefix(values[i]) return values + def _resolve_prefixes(self, other_value: dict) -> dict: + other_value = other_value.copy() + for key, value in other_value.items(): + if isinstance(value, str): + other_value[key] = self.replace_prefix(value) + elif isinstance(value, list): + other_value[key] = self.replace_all_prefixes(value) + return other_value + def _normalize_grouping_columns( self, within: Union[str, Sequence[str]] ) -> List[str]: if within is None: raise ValueError("within parameter is required") - if isinstance(within, (list, tuple)): - columns = [self.replace_prefix(column) for column in within] - else: - columns = [self.replace_prefix(within)] + columns = list(within) if isinstance(within, (list, tuple)) else [within] if not columns or any( not isinstance(column, str) or not column for column in columns ): @@ -174,7 +182,7 @@ def is_column_of_iterables(self, column): @log_operator_execution @type_operator(FIELD_DATAFRAME) def exists(self, other_value): - target_column = self.replace_prefix(other_value.get("target")) + target_column = other_value.get("target") def check_row(row): return any(target_column in item for item in row if isinstance(item, list)) @@ -288,16 +296,12 @@ def _check_inequality( @log_operator_execution @type_operator(FIELD_DATAFRAME) def equal_to(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) value_is_reference = other_value.get("value_is_reference", False) type_insensitive = other_value.get("type_insensitive", False) round_values = other_value.get("round_values", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") return self.value.apply( lambda row: self._check_equality( row, @@ -315,16 +319,13 @@ def equal_to(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def equal_to_case_insensitive(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) value_is_reference = other_value.get("value_is_reference", False) type_insensitive = other_value.get("type_insensitive", False) round_values = other_value.get("round_values", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") + return self.value.apply( lambda row: self._check_equality( row, @@ -343,16 +344,13 @@ def equal_to_case_insensitive(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def not_equal_to_case_insensitive(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) value_is_reference = other_value.get("value_is_reference", False) type_insensitive = other_value.get("type_insensitive", False) round_values = other_value.get("round_values", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") + return self.value.apply( lambda row: self._check_inequality( row, @@ -370,16 +368,13 @@ def not_equal_to_case_insensitive(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def not_equal_to(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) value_is_reference = other_value.get("value_is_reference", False) type_insensitive = other_value.get("type_insensitive", False) round_values = other_value.get("round_values", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") + return self.value.apply( lambda row: self._check_inequality( row, @@ -400,15 +395,11 @@ def suffix_equal_to(self, other_value: dict): """ Checks if target suffix is equal to comparator. """ - target: str = self.replace_prefix(other_value.get("target")) + target: str = other_value.get("target") value_is_literal: bool = other_value.get("value_is_literal", False) - comparator: Union[str, Any] = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator: Union[str, Any] = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) - suffix: int = self.replace_prefix(other_value.get("suffix")) + suffix: int = other_value.get("suffix") return self._check_equality_of_string_part( target, comparison_data, "suffix", suffix ) @@ -427,18 +418,14 @@ def prefix_equal_to(self, other_value: dict): """ Checks if target prefix is equal to comparator. """ - target: str = self.replace_prefix(other_value.get("target")) + target: str = other_value.get("target") value_is_literal: bool = other_value.get("value_is_literal", False) - comparator: Union[str, Any] = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator: Union[str, Any] = other_value.get("comparator") if comparator == "DOMAIN": comparison_data = self.column_prefix_map["--"] else: comparison_data = self.get_comparator_data(comparator, value_is_literal) - prefix: int = self.replace_prefix(other_value.get("prefix")) + prefix: int = other_value.get("prefix") return self._check_equality_of_string_part( target, comparison_data, "prefix", prefix ) @@ -457,13 +444,9 @@ def prefix_is_contained_by(self, other_value: dict): """ Checks if target prefix is contained by the comparator. """ - target: str = self.replace_prefix(other_value.get("target")) + target: str = other_value.get("target") value_is_literal: bool = other_value.get("value_is_literal", False) - comparator: Union[str, Any] = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator: Union[str, Any] = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) prefix_length: int = other_value.get("prefix") series_to_validate = self._get_string_part_series( @@ -482,13 +465,9 @@ def suffix_is_contained_by(self, other_value: dict): """ Checks if target prefix is equal to comparator. """ - target: str = self.replace_prefix(other_value.get("target")) + target: str = other_value.get("target") value_is_literal: bool = other_value.get("value_is_literal", False) - comparator: Union[str, Any] = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator: Union[str, Any] = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) suffix_length: int = other_value.get("suffix") series_to_validate = self._get_string_part_series( @@ -557,13 +536,9 @@ def _to_numeric(self, target, **kwargs): @log_operator_execution @type_operator(FIELD_DATAFRAME) def less_than(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) target_column = self._to_numeric(self.value[target], errors="coerce") if self.value.is_series(comparison_data): @@ -574,13 +549,9 @@ def less_than(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def less_than_or_equal_to(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) target_column = self._to_numeric(self.value[target], errors="coerce") if self.value.is_series(comparison_data): @@ -591,13 +562,9 @@ def less_than_or_equal_to(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def greater_than_or_equal_to(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) target_column = self._to_numeric(self.value[target], errors="coerce") if self.value.is_series(comparison_data): @@ -608,13 +575,9 @@ def greater_than_or_equal_to(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def greater_than(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) target_column = self._to_numeric(self.value[target], errors="coerce") if self.value.is_series(comparison_data): @@ -625,13 +588,9 @@ def greater_than(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def contains(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) if self.is_column_of_iterables(self.value[target]) or isinstance( comparison_data, str @@ -655,13 +614,9 @@ def _series_is_in(self, target, comparison_data): @log_operator_execution @type_operator(FIELD_DATAFRAME) def contains_case_insensitive(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) - comparator = ( - self.replace_prefix(other_value.get("comparator")) - if not value_is_literal - else other_value.get("comparator") - ) + comparator = other_value.get("comparator") comparison_data = self.get_comparator_data(comparator, value_is_literal) comparison_data = self.convert_string_data_to_lower(comparison_data) if self.is_column_of_iterables(self.value[target]): @@ -687,12 +642,9 @@ def does_not_contain_case_insensitive(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def is_contained_by(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal = other_value.get("value_is_literal", False) comparator = other_value.get("comparator") - if isinstance(comparator, str) and not value_is_literal: - # column name provided - comparator = self.replace_prefix(comparator) comparison_data = self.get_comparator_data(comparator, value_is_literal) target_data = self.value[target] if self.is_column_of_iterables(target_data): @@ -730,14 +682,11 @@ def is_not_contained_by(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def is_contained_by_case_insensitive(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator", []) value_is_literal = other_value.get("value_is_literal", False) if isinstance(comparator, list): comparator = [val.lower() for val in comparator] - elif isinstance(comparator, str) and not value_is_literal: - # column name provided - comparator = self.replace_prefix(comparator) comparison_data = self.get_comparator_data(comparator, value_is_literal) if self.is_column_of_iterables(comparison_data): results = vectorized_case_insensitive_is_in( @@ -758,7 +707,7 @@ def is_not_contained_by_case_insensitive(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def prefix_matches_regex(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") prefix = other_value.get("prefix") converted_strings = self.value[target].map( @@ -772,7 +721,7 @@ def prefix_matches_regex(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def not_prefix_matches_regex(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") prefix = other_value.get("prefix") converted_strings = self.value[target].map( @@ -786,7 +735,7 @@ def not_prefix_matches_regex(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def suffix_matches_regex(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") suffix = other_value.get("suffix") converted_strings = self.value[target].map( @@ -800,7 +749,7 @@ def suffix_matches_regex(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def not_suffix_matches_regex(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") suffix = other_value.get("suffix") converted_strings = self.value[target].map( @@ -814,7 +763,7 @@ def not_suffix_matches_regex(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def matches_regex(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") converted_strings = self.value[target].map( lambda x: self._regex_str_conversion(x) @@ -827,7 +776,7 @@ def matches_regex(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def not_matches_regex(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") converted_strings = self.value[target].map( lambda x: self._regex_str_conversion(x) @@ -845,7 +794,7 @@ def equals_string_part(self, other_value): equal the result of parsing the value in the comparison column with a regex """ - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") regex = other_value.get("regex") value_is_literal: bool = other_value.get("value_is_literal", False) @@ -869,7 +818,7 @@ def does_not_equal_string_part(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def starts_with(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") value_is_literal: bool = other_value.get("value_is_literal", False) comparison_data = self.get_comparator_data(comparator, value_is_literal) @@ -882,7 +831,7 @@ def starts_with(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def ends_with(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") value_is_literal: bool = other_value.get("value_is_literal", False) comparison_data = self.get_comparator_data(comparator, value_is_literal) @@ -900,7 +849,7 @@ def has_equal_length(self, other_value: dict): If comparing two columns (value_is_literal is False), the operator compares lengths of values in these columns. """ - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") value_is_literal: bool = other_value.get("value_is_literal", False) comparison_data = self.get_comparator_data(comparator, value_is_literal) @@ -931,7 +880,7 @@ def longer_than(self, other_value: dict): If comparing two columns (value_is_literal is False), the operator compares lengths of values in these columns. """ - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") value_is_literal: bool = other_value.get("value_is_literal", False) comparison_data = self.get_comparator_data(comparator, value_is_literal) @@ -947,7 +896,7 @@ def longer_than(self, other_value: dict): @log_operator_execution @type_operator(FIELD_DATAFRAME) def longer_than_or_equal_to(self, other_value: dict): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") value_is_literal: bool = other_value.get("value_is_literal", False) comparison_data = self.get_comparator_data(comparator, value_is_literal) @@ -976,7 +925,7 @@ def split_parts_have_equal_length(self, other_value: dict): """ Splits string values by a separator and checks if both parts have equal length. """ - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") separator = other_value.get("separator", "/") target_series = self.value[target] @@ -1009,7 +958,7 @@ def split_parts_have_unequal_length(self, other_value: dict): @log_operator_execution @type_operator(FIELD_DATAFRAME) def empty(self, other_value: dict): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") series = self.value[target] def check_empty(x): @@ -1031,9 +980,9 @@ def check_empty(x): @log_operator_execution @type_operator(FIELD_DATAFRAME) def empty_within_except_last_row(self, other_value: dict): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") - order_by_column: str = self.replace_prefix(other_value.get("ordering")) + order_by_column: str = other_value.get("ordering") # group all targets by comparator if order_by_column: ordered_df = self.value.sort_values(by=[comparator, order_by_column]) @@ -1061,9 +1010,9 @@ def non_empty(self, other_value: dict): @log_operator_execution @type_operator(FIELD_DATAFRAME) def non_empty_within_except_last_row(self, other_value: dict): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") - order_by_column: str = self.replace_prefix(other_value.get("ordering")) + order_by_column: str = other_value.get("ordering") # group all targets by comparator if order_by_column: ordered_df = self.value.sort_values(by=[comparator, order_by_column]) @@ -1088,7 +1037,7 @@ def non_empty_within_except_last_row(self, other_value: dict): @log_operator_execution @type_operator(FIELD_DATAFRAME) def contains_all(self, other_value: dict): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value_is_literal: bool = other_value.get("value_is_literal", False) comparator = other_value.get("comparator") if self.is_column_of_iterables( @@ -1105,7 +1054,6 @@ def contains_all(self, other_value: dict): # get column as array of values values = flatten_list(self.value, comparator) else: - comparator = self.replace_prefix(comparator) values = self.value[comparator].unique() results = set(values).issubset(set(self.value[target].unique())) return self.value.convert_to_series(results) @@ -1125,7 +1073,7 @@ def invalid_date(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def invalid_duration(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") if other_value.get("negative") is False: results = ~vectorized_is_valid_duration(self.value[target], False) else: @@ -1133,8 +1081,8 @@ def invalid_duration(self, other_value): return self.value.convert_to_series(results) def date_comparison(self, other_value, operator): - target = self.replace_prefix(other_value.get("target")) - comparator = self.replace_prefix(other_value.get("comparator")) + target = other_value.get("target") + comparator = other_value.get("comparator") value_is_literal: bool = other_value.get("value_is_literal", False) comparison_data = self.get_comparator_data(comparator, value_is_literal) component = other_value.get("date_component") @@ -1192,18 +1140,16 @@ def is_complete_date(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def is_inconsistent_across_dataset(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") grouping_cols = [] if isinstance(comparator, str): - col_name = self.replace_prefix(comparator) - if col_name in self.value.columns: - grouping_cols.append(col_name) + if comparator in self.value.columns: + grouping_cols.append(comparator) else: for col in comparator: - col_name = self.replace_prefix(col) - if col_name in self.value.columns: - grouping_cols.append(col_name) + if col in self.value.columns: + grouping_cols.append(col) df_check = self.value[grouping_cols + [target]].copy() df_check = df_check.fillna("_NaN_") results = pd.Series(False, index=df_check.index) @@ -1224,14 +1170,13 @@ def is_inconsistent_across_dataset(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def is_unique_set(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") regex_pattern = other_value.get("regex") values = [target, comparator] target_data = flatten_list(self.value, values) target_names = [] for target_name in target_data: - target_name = self.replace_prefix(target_name) if target_name in self.value.columns: target_names.append(target_name) target_names = list(set(target_names)) @@ -1276,13 +1221,11 @@ def is_not_unique_relationship(self, other_value): A violation occurs when a NON-NULL value in either column maps to multiple different values in the other column. """ - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") comparator = other_value.get("comparator") if isinstance(comparator, list): - comparator = self.replace_all_prefixes(comparator) columns = [target] + comparator else: - comparator = self.replace_prefix(comparator) columns = [target, comparator] df_subset = self.value[columns].dropna(how="all") @@ -1405,7 +1348,7 @@ def is_unique_relationship(self, other_value): @log_operator_execution @type_operator(FIELD_DATAFRAME) def is_ordered_set(self, other_value): - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") value = other_value.get("comparator") if not isinstance(value, str): raise Exception("Comparator must be a single String value") @@ -1468,10 +1411,10 @@ def has_next_corresponding_record(self, other_value: dict): and first row from comparator and compare the resulting contents. The result is reported for target. """ - target = self.replace_prefix(other_value.get("target")) - comparator = self.replace_prefix(other_value.get("comparator")) - group_by_column: str = self.replace_prefix(other_value.get("within")) - order_by_column: str = self.replace_prefix(other_value.get("ordering")) + target = other_value.get("target") + comparator = other_value.get("comparator") + group_by_column: str = other_value.get("within") + order_by_column: str = other_value.get("ordering") target_columns = [target, comparator, group_by_column, order_by_column] ordered_df = self.value[target_columns].sort_values(by=[order_by_column]) grouped_df = ordered_df.groupby(group_by_column) @@ -1520,9 +1463,9 @@ def present_on_multiple_rows_within(self, other_value: dict): within a group_by column. The dataframe is grouped by a certain column and the check is applied to each group. """ - target = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") min_count: int = other_value.get("comparator") or 1 - group_by_column = self.replace_prefix(other_value.get("within")) + group_by_column = other_value.get("within") grouped = self.value.groupby([group_by_column, target]) meta = (target, bool) results = grouped.apply( @@ -1555,7 +1498,7 @@ def inconsistent_enumerated_columns(self, other_value: dict): Note that the initial variable will not have an index (VARIABLE) and the next enumerated variable has index 1 (VARIABLE1). """ - variable_name: str = self.replace_prefix(other_value.get("target")) + variable_name: str = other_value.get("target") df = self.value pattern = rf"^{re.escape(variable_name)}(\d*)$" matching_columns = [col for col in df.columns if re.match(pattern, col)] @@ -1582,8 +1525,8 @@ def check_inconsistency(row): @log_operator_execution @type_operator(FIELD_DATAFRAME) def references_correct_codelist(self, other_value: dict): - target: str = self.replace_prefix(other_value.get("target")) - comparator = self.replace_prefix(other_value.get("comparator")) + target = other_value.get("target") + comparator = other_value.get("comparator") result = self.value.apply( lambda row: self.valid_codelist_reference(row[target], row[comparator]), axis=1, @@ -1628,7 +1571,7 @@ def has_different_values(self, other_value: dict): """ The operator ensures that the target column has different values. """ - target: str = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") is_valid: bool = len(self.value[target].unique()) > 1 return self.value.convert_to_series([is_valid] * len(self.value[target])) @@ -1643,7 +1586,7 @@ def is_ordered_by(self, other_value: dict): """ Checking validity based on target order. """ - target: str = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") sort_order: str = other_value.get("order", "asc") if sort_order not in ["asc", "dsc"]: raise ValueError("invalid sorting order") @@ -1670,8 +1613,8 @@ def value_has_multiple_references(self, other_value: dict): Requires a target column and a reference count column whose values are a dictionary containing the number of times that value appears. """ - target: str = self.replace_prefix(other_value.get("target")) - reference_count_column: str = self.replace_prefix(other_value.get("comparator")) + target = other_value.get("target") + reference_count_column: str = other_value.get("comparator") result = np.where( vectorized_get_dict_key( self.value[reference_count_column], self.value[target] @@ -1769,7 +1712,7 @@ def target_is_sorted_by(self, other_value: dict): """ Checking the sort order based on comparators, including date overlap checks """ - target: str = self.replace_prefix(other_value.get("target")) + target = other_value.get("target") within_columns = self._normalize_grouping_columns(other_value.get("within")) columns = other_value["comparator"] result = pd.Series([True] * len(self.value), index=self.value.index) @@ -1828,8 +1771,8 @@ def target_is_not_sorted_by(self, other_value: dict): @log_operator_execution @type_operator(FIELD_DATAFRAME) def shares_at_least_one_element_with(self, other_value: dict): - target: str = self.replace_prefix(other_value.get("target")) - comparator: str = self.replace_prefix(other_value.get("comparator")) + target: str = other_value.get("target") + comparator: str = other_value.get("comparator") def check_shared_elements(row): target_set = ( @@ -1849,8 +1792,8 @@ def check_shared_elements(row): @log_operator_execution @type_operator(FIELD_DATAFRAME) def shares_exactly_one_element_with(self, other_value: dict): - target: str = self.replace_prefix(other_value.get("target")) - comparator: str = self.replace_prefix(other_value.get("comparator")) + target: str = other_value.get("target") + comparator: str = other_value.get("comparator") def check_exactly_one_shared_element(row): target_set = ( @@ -1870,8 +1813,8 @@ def check_exactly_one_shared_element(row): @log_operator_execution @type_operator(FIELD_DATAFRAME) def shares_no_elements_with(self, other_value: dict): - target: str = self.replace_prefix(other_value.get("target")) - comparator: str = self.replace_prefix(other_value.get("comparator")) + target: str = other_value.get("target") + comparator: str = other_value.get("comparator") def check_no_shared_elements(row): target_set = ( @@ -1891,8 +1834,8 @@ def check_no_shared_elements(row): @log_operator_execution @type_operator(FIELD_DATAFRAME) def is_ordered_subset_of(self, other_value: dict): - target = self.replace_prefix(other_value.get("target")) - comparator = self.replace_prefix(other_value.get("comparator")) + target: str = other_value.get("target") + comparator: str = other_value.get("comparator") missing_columns = set() def check_order(row): diff --git a/cdisc_rules_engine/operations/variable_is_null.py b/cdisc_rules_engine/operations/variable_is_null.py index 58fa93197..18758ab38 100644 --- a/cdisc_rules_engine/operations/variable_is_null.py +++ b/cdisc_rules_engine/operations/variable_is_null.py @@ -21,4 +21,4 @@ def _is_target_variable_null(self, dataframe, target_variable: str) -> bool: if target_variable not in dataframe: return True series = dataframe[target_variable] - return series.mask(series == "").isnull().all() + return (series.isnull() | (series == "")).all() diff --git a/tests/unit/test_check_operators/test_containment_checks.py b/tests/unit/test_check_operators/test_containment_checks.py index 936fd6ed9..cb997005b 100644 --- a/tests/unit/test_check_operators/test_containment_checks.py +++ b/tests/unit/test_check_operators/test_containment_checks.py @@ -259,45 +259,85 @@ def test_not_contains_all(data, comparator, dataset_type, expected_result): @pytest.mark.parametrize( - "data,comparator,dataset_type, expected_result", + "data, comparator, dataset_type, column_prefix_map, value_is_literal, expected_result", [ ( {"target": ["Ctt", "Btt", "A"], "VAR2": ["A", "btt", "lll"]}, ["Ctt", "B", "A"], PandasDataset, + {}, + True, [True, False, True], ), ( {"target": ["Ctt", "Btt", "A"], "VAR2": ["A", "btt", "lll"]}, ["Ctt", "B", "A"], DaskDataset, + {}, + True, [True, False, True], ), ( {"target": ["A", "B", "C"]}, ["C", "Z", "A"], DaskDataset, + {}, + True, [True, False, True], ), ( {"target": [1, 2, 3], "VAR2": [[1, 2], [3], [3]]}, "VAR2", PandasDataset, + {}, + False, [True, False, True], ), ( {"target": [1, 2, 3], "VAR2": [[1, 2], [3], [3]]}, "VAR2", DaskDataset, + {}, + False, [True, False, True], ), + ( + { + "target": ["TSPARM", "TSORRESU", "AGEU", "TSDOSU"], + "DOMAIN": ["TS", "TS", "DM", "TS"], + }, + ["--ORRESU", "--STRESU", "--DOSU", "--TEST", "QLABEL", "--PARM"], + PandasDataset, + {"--": "TS"}, + True, + [True, True, False, True], + ), + ( + { + "target": ["TSPARM", "TSORRESU", "AGEU", "TSDOSU"], + "DOMAIN": ["TS", "TS", "DM", "TS"], + }, + ["--ORRESU", "--STRESU", "--DOSU", "--TEST", "QLABEL", "--PARM"], + DaskDataset, + {"--": "TS"}, + True, + [True, True, False, True], + ), ], ) -def test_is_contained_by(data, comparator, dataset_type, expected_result): +def test_is_contained_by( + data, comparator, dataset_type, column_prefix_map, value_is_literal, expected_result +): df = dataset_type.from_dict(data) - dataframe_operator = DataframeType({"value": df}) + dataframe_operator = DataframeType( + {"value": df, "column_prefix_map": column_prefix_map} + ) result = dataframe_operator.is_contained_by( - {"target": "target", "comparator": comparator} + { + "target": "target", + "comparator": comparator, + "value_is_literal": value_is_literal, + } ) assert result.equals(df.convert_to_series(expected_result)) diff --git a/tests/unit/test_operations/test_variable_is_null.py b/tests/unit/test_operations/test_variable_is_null.py index 286e57339..f2fe30cf6 100644 --- a/tests/unit/test_operations/test_variable_is_null.py +++ b/tests/unit/test_operations/test_variable_is_null.py @@ -8,41 +8,76 @@ @pytest.mark.parametrize( - "data, expected", + "data, target_var, expected", [ ( - PandasDataset.from_dict({"AEVAR": ["A", "B", "C"]}), + PandasDataset.from_dict({"VAR1": ["A", "B", "C"], "VAR2": [1, 2, 3]}), + "VAR1", False, ), ( - DaskDataset.from_dict({"AEVAR": [1, 2, 3]}), + PandasDataset.from_dict({"VAR1": ["", None, "C"], "VAR2": [1, 2, 3]}), + "VAR1", False, ), ( - DaskDataset.from_dict({"AEVAR": ["", None, "C"]}), + DaskDataset.from_dict({"VAR1": ["", None, "C"], "VAR2": [1, 2, 3]}), + "VAR1", False, ), ( - PandasDataset.from_dict({"AEVAR": [None, None, 3]}), + PandasDataset.from_dict({"VAR1": ["", None, ""], "VAR2": [1, 2, 3]}), + "VAR1", + True, + ), + ( + DaskDataset.from_dict({"VAR1": ["", None, ""], "VAR2": [1, 2, 3]}), + "VAR1", + True, + ), + ( + PandasDataset.from_dict( + {"VAR1": [None, None, None, "X"], "VAR2": [1, 2, 3, 4]} + ), + "VAR1", False, ), ( - PandasDataset.from_dict({"AEVAR": ["", None]}), + PandasDataset.from_dict( + {"VAR1": ["", "", "", "data"], "VAR2": [1, 2, 3, 4]} + ), + "VAR1", + False, + ), + ( + PandasDataset.from_dict({"VAR1": [None, None, None], "VAR2": [1, 2, 3]}), + "VAR1", + True, + ), + ( + PandasDataset.from_dict({"VAR2": ["A", "B", "C"]}), + "VAR1", + True, + ), + ( + DaskDataset.from_dict({"VAR2": ["A", "B", "C"]}), + "VAR1", True, ), ( - DaskDataset.from_dict({"BCVAR": ["A", "B", "C"]}), + PandasDataset.from_dict({"VAR2": ["A", "B", "C"]}), + "NONEXISTENT", True, ), ], ) def test_variable_is_null( - data, expected, mock_data_service, operation_params: OperationParams + data, target_var, expected, mock_data_service, operation_params: OperationParams ): config = ConfigService() cache = CacheServiceFactory(config).get_cache_service() operation_params.dataframe = data - operation_params.target = "AEVAR" + operation_params.target = target_var operation_params.domain = "AE" mock_data_service.get_dataset.return_value = data mock_data_service.dataset_implementation = data.__class__