From 1f2f2dcc5b7cc49146b500f56ee8b5967a397f6f Mon Sep 17 00:00:00 2001 From: Pablo Badilla Date: Tue, 22 Jul 2025 11:35:13 -0400 Subject: [PATCH] feature: update base metric --- tests/metrics/test_base_metric.py | 123 +++++++++--------- wefe/metrics/ECT.py | 7 +- wefe/metrics/MAC.py | 7 +- wefe/metrics/RIPA.py | 7 +- wefe/metrics/RND.py | 7 +- wefe/metrics/RNSB.py | 7 +- wefe/metrics/WEAT.py | 7 +- wefe/metrics/base_metric.py | 204 +++++++++++++++++++----------- 8 files changed, 232 insertions(+), 137 deletions(-) diff --git a/tests/metrics/test_base_metric.py b/tests/metrics/test_base_metric.py index 0703cff..6b617a2 100644 --- a/tests/metrics/test_base_metric.py +++ b/tests/metrics/test_base_metric.py @@ -1,3 +1,5 @@ +"""Unit tests for the BaseMetric class in the wefe.metrics.base_metric module.""" + import pytest from wefe.metrics.base_metric import BaseMetric @@ -10,6 +12,21 @@ def test_base_metric_input_checking( query_2t2a_1: Query, query_3t2a_1: Query, ) -> None: + """Test input validation for the BaseMetric class. + This test verifies that the `_check_input` method of `BaseMetric` correctly + raises exceptions when provided with invalid inputs, such as incorrect types + for `query` and `model`, or mismatched cardinalities for target and attribute sets. + + Parameters + ---------- + model : WordEmbeddingModel + A word embedding model instance used for metric evaluation. + query_2t2a_1 : Query + A query instance with 2 target sets and 2 attribute sets. + query_3t2a_1 : Query + A query instance with 3 target sets and 2 attribute sets. + + """ # Create and configure base metric testing. # disable abstract methods. @@ -21,84 +38,74 @@ def test_base_metric_input_checking( base_metric.metric_short_name = "TM" with pytest.raises(TypeError, match="query should be a Query instance, got*"): - base_metric._check_input(None, model, {}) - - with pytest.raises( - TypeError, match="word_embedding should be a WordEmbeddingModel instance, got*" - ): - base_metric._check_input(query_2t2a_1, None, {}) - - with pytest.raises( - Exception, - match="The cardinality of the set of target words of the 'Flowers, Weapons and " - "Instruments wrt Pleasant and Unpleasant' query does not match with the " - "cardinality required by TM. Provided query: 3, metric: 2", - ): - base_metric._check_input(query_3t2a_1, model, {}) - - with pytest.raises( - Exception, - match=( - "The cardinality of the set of attribute words of the 'Flowers and Insects " - "wrt Pleasant and Unpleasant' query does not match with the cardinality " - "required by TM. Provided query: 2, metric: 3" - ), - ): - base_metric._check_input(query_2t2a_1, model, {}) - - -def test_validate_old_preprocessor_args_inputs( - model: WordEmbeddingModel, - query_2t2a_1: Query, -) -> None: - # instance test metric - BaseMetric.__abstractmethods__ = frozenset() - base_metric = BaseMetric() - base_metric.metric_template = (2, 2) - base_metric.metric_name = "Test Metric" - base_metric.metric_short_name = "TM" + base_metric._check_input( + query=None, + model=model, + lost_vocabulary_threshold=0.2, + warn_not_found_words=True, + ) with pytest.raises( - DeprecationWarning, - match=( - r"preprocessor_args argument is deprecated. " - r"Use preprocessors=\[\{'uppercase': True\}\] instead.*." - ), + TypeError, + match="model should be a WordEmbeddingModel instance, got: .", ): base_metric._check_input( - query_2t2a_1, model, {"preprocessor_args": {"uppercase": True}} + query=query_2t2a_1, + model=None, + lost_vocabulary_threshold=0.2, + warn_not_found_words=True, ) with pytest.raises( - DeprecationWarning, + Exception, match=( - r"secondary_preprocessor_args is deprecated. " - r"Use preprocessors=\[\{\}, \{'uppercase': True\}\] instead.*." + r"The cardinality of the target sets of the 'Flowers, Weapons and " + r"Instruments wrt Pleasant and Unpleasant' query \(3\) does not match " + r"the cardinality required by TM \(2\)." ), ): base_metric._check_input( - query_2t2a_1, model, {"secondary_preprocessor_args": {"uppercase": True}} + query=query_3t2a_1, + model=model, + lost_vocabulary_threshold=0.2, + warn_not_found_words=True, ) with pytest.raises( - DeprecationWarning, + Exception, match=( - r"preprocessor_args and secondary_preprocessor_args arguments are " - r"deprecated. Use preprocessors=\[\{'uppercase': True\}, \{'uppercase': " - r"True\}\] instead." + r"The cardinality of the attribute sets of the 'Flowers and Insects wrt " + r"Pleasant and Unpleasant' query \(2\) does not match the cardinality " + r"required by TM \(3\)." ), ): base_metric._check_input( - query_2t2a_1, - model, - { - "preprocessor_args": {"uppercase": True}, - "secondary_preprocessor_args": {"uppercase": True}, - }, + query=query_2t2a_1, + model=model, + lost_vocabulary_threshold=0.2, + warn_not_found_words=True, ) -def test_run_query(model: WordEmbeddingModel, query_2t2a_1: Query) -> None: +def test_run_query( + model: WordEmbeddingModel, + query_2t2a_1: Query, +) -> None: + """Test that the `run_query` method of `BaseMetric` raises a NotImplementedError. + + Parameters + ---------- + model : WordEmbeddingModel + The word embedding model to be used in the query. + query_2t2a_1 : Query + The query object to be passed to the metric. + + Raises + ------ + NotImplementedError + If the `run_query` method is not implemented in `BaseMetric`. + + """ # disable abstract methods. BaseMetric.__abstractmethods__ = frozenset() base_metric = BaseMetric() @@ -110,4 +117,4 @@ def test_run_query(model: WordEmbeddingModel, query_2t2a_1: Query) -> None: with pytest.raises( NotImplementedError, ): - base_metric.run_query(query_2t2a_1, model) + base_metric.run_query(query=query_2t2a_1, model=model) diff --git a/wefe/metrics/ECT.py b/wefe/metrics/ECT.py index 08c01f8..20d71d3 100644 --- a/wefe/metrics/ECT.py +++ b/wefe/metrics/ECT.py @@ -160,7 +160,12 @@ def run_query( """ # check the types of the provided arguments (only the defaults). - self._check_input(query, model, locals()) + self._check_input( + query=query, + model=model, + lost_vocabulary_threshold=lost_vocabulary_threshold, + warn_not_found_words=warn_not_found_words, + ) # transform query word sets into embeddings embeddings = get_embeddings_from_query( diff --git a/wefe/metrics/MAC.py b/wefe/metrics/MAC.py index 66efa53..5004b62 100644 --- a/wefe/metrics/MAC.py +++ b/wefe/metrics/MAC.py @@ -448,7 +448,12 @@ def run_query( """ # noqa: E501 # check the types of the provided arguments (only the defaults). - self._check_input(query, model, locals()) + self._check_input( + query=query, + model=model, + lost_vocabulary_threshold=lost_vocabulary_threshold, + warn_not_found_words=warn_not_found_words, + ) # transform query word sets into embeddings embeddings = get_embeddings_from_query( diff --git a/wefe/metrics/RIPA.py b/wefe/metrics/RIPA.py index 5b528c2..a974784 100644 --- a/wefe/metrics/RIPA.py +++ b/wefe/metrics/RIPA.py @@ -268,7 +268,12 @@ def run_query( """ # check the types of the provided arguments (only the defaults). - self._check_input(query, model, locals()) + self._check_input( + query=query, + model=model, + lost_vocabulary_threshold=lost_vocabulary_threshold, + warn_not_found_words=warn_not_found_words, + ) # transform query word sets into embeddings embeddings = get_embeddings_from_query( diff --git a/wefe/metrics/RND.py b/wefe/metrics/RND.py index 89ae306..4b6d258 100644 --- a/wefe/metrics/RND.py +++ b/wefe/metrics/RND.py @@ -261,7 +261,12 @@ def run_query( """ # check the types of the provided arguments (only the defaults). - self._check_input(query, model, locals()) + self._check_input( + query=query, + model=model, + lost_vocabulary_threshold=lost_vocabulary_threshold, + warn_not_found_words=warn_not_found_words, + ) # transform query word sets into embeddings embeddings = get_embeddings_from_query( diff --git a/wefe/metrics/RNSB.py b/wefe/metrics/RNSB.py index 99b8527..2a500ef 100644 --- a/wefe/metrics/RNSB.py +++ b/wefe/metrics/RNSB.py @@ -792,7 +792,12 @@ def run_query( """ # check the types of the provided arguments (only the defaults). - self._check_input(query, model, locals()) + self._check_input( + query=query, + model=model, + lost_vocabulary_threshold=lost_vocabulary_threshold, + warn_not_found_words=warn_not_found_words, + ) if n_iterations > 1 and random_state is not None: raise ValueError( diff --git a/wefe/metrics/WEAT.py b/wefe/metrics/WEAT.py index 85c62ac..1120999 100644 --- a/wefe/metrics/WEAT.py +++ b/wefe/metrics/WEAT.py @@ -439,7 +439,12 @@ def run_query( """ # check the types of the provided arguments (only the defaults). - self._check_input(query, model, locals()) + self._check_input( + query=query, + model=model, + lost_vocabulary_threshold=lost_vocabulary_threshold, + warn_not_found_words=warn_not_found_words, + ) # transform query word sets into embeddings embeddings = get_embeddings_from_query( diff --git a/wefe/metrics/base_metric.py b/wefe/metrics/base_metric.py index 019a5f3..79a0790 100644 --- a/wefe/metrics/base_metric.py +++ b/wefe/metrics/base_metric.py @@ -1,33 +1,52 @@ """Base metric class that all metrics must extend..""" from abc import ABC, abstractmethod -from typing import Any, Callable, Union +from typing import Any, Callable, ClassVar, Union from wefe.query import Query from wefe.word_embedding_model import WordEmbeddingModel class BaseMetric(ABC): - """A base class to implement any metric following the framework described by WEFE. + """An abstract base class for implementing fairness metrics in the WEFE. - It contains the name of the metric, the templates (cardinalities) that it supports - and the abstract function run_query, which must be implemented by any metric that - extends this class. - """ + This class provides a template for all metrics, ensuring consistent input + validation and parameter handling. Subclasses are required to define metric-specific + attributes and implement the core calculation logic in the `run_query` + method. + + Attributes + ---------- + metric_template : ClassVar[tuple[Union[int, str], Union[int, str]]] + A tuple indicating the required cardinality of target and attribute sets, + e.g., (1, 1) or (2, 'n'). 'n' denotes any number of sets. + This must be overridden by subclasses. - # A tuple that indicates the cardinality of target and attribute sets - metric_template: tuple[Union[int, str], Union[int, str]] + metric_name : ClassVar[str] + The full name of the metric. + This must be overridden by subclasses. - # The name of the metric - metric_name: str + metric_short_name : ClassVar[str] + The initials or a short name for the metric. + This must be overridden by subclasses. + + """ - # The initials or short name of the metric - metric_short_name: str + # These attributes MUST be overridden by any class that extends BaseMetric. + metric_template: ClassVar[tuple[Union[int, str], Union[int, str]]] + metric_name: ClassVar[str] + metric_short_name: ClassVar[str] def _check_input( - self, query: Query, model: WordEmbeddingModel, _locals: dict[str, Any] + self, + query: Query, + model: WordEmbeddingModel, + lost_vocabulary_threshold: float, + warn_not_found_words: bool, ) -> None: - """Check if Query and WordEmbeddingModel parameters are valid. + """Check if the parameters for run_query are valid. + + This private method is called by the `run_query` template method. Parameters ---------- @@ -35,89 +54,66 @@ def _check_input( The query that the method will execute. model : WordEmbeddingModel A word embedding model. - _locals: Dict[str, Any] - The extra arguments of run_query. + lost_vocabulary_threshold : float + The threshold for the proportion of lost words. + warn_not_found_words : bool + Specifies whether to warn about out-of-vocabulary words. + Raises ------ TypeError - if query is not instance of Query. + If `query` is not an instance of `Query`. TypeError - if word_embedding is not instance of . + If `model` is not an instance of `WordEmbeddingModel`. TypeError - if lost_vocabulary_threshold is not a float number. + If `lost_vocabulary_threshold` is not a float. TypeError - if warn_filtered_words is not a bool. - Exception - if the metric require different number of target sets than - the delivered query - Exception - if the metric require different number of attribute sets than - the delivered query + If `warn_not_found_words` is not a bool. + ValueError + If the query's template cardinality does not match the metric's + required template. """ - # check if the query passed is a instance of Query if not isinstance(query, Query): - raise TypeError(f"query should be a Query instance, got {query}") - - # check if the word_embedding is a instance of + raise TypeError(f"query should be a Query instance, got: {type(query)}.") if not isinstance(model, WordEmbeddingModel): raise TypeError( - f"word_embedding should be a WordEmbeddingModel instance, got: {model}" + f"model should be a WordEmbeddingModel instance, got: {type(model)}." ) - # templates: + if not isinstance(lost_vocabulary_threshold, float): + raise TypeError( + f"lost_vocabulary_threshold should be a float, got: " + f"{type(lost_vocabulary_threshold)}." + ) - # check the cardinality of the target sets of the provided query + if not isinstance(warn_not_found_words, bool): + raise TypeError( + f"warn_not_found_words should be a bool, got: " + f"{type(warn_not_found_words)}." + ) + + # Check the cardinality of the target sets if ( self.metric_template[0] != "n" and query.template[0] != self.metric_template[0] ): - raise Exception( - f"The cardinality of the set of target words of the " - f"'{query.query_name}' query does not match with the cardinality " - f"required by {self.metric_short_name}. Provided query: " - f"{query.template[0]}, metric: {self.metric_template[0]}." + raise ValueError( + f"The cardinality of the target sets of the '{query.query_name}' " + f"query ({query.template[0]}) does not match the cardinality " + f"required by {self.metric_short_name} ({self.metric_template[0]})." ) - # check the cardinality of the attribute sets of the provided query + # Check the cardinality of the attribute sets if ( self.metric_template[1] != "n" and query.template[1] != self.metric_template[1] ): - raise Exception( - "The cardinality of the set of attribute words of the " - f"'{query.query_name}' query does not match with the cardinality " - f"required by {self.metric_short_name}. " - f"Provided query: {query.template[1]}, metric: " - f"{self.metric_template[1]}." - ) - - preprocessor_in_args = "preprocessor_args" in _locals - secondary_preprocessor_in_args = "secondary_preprocessor_args" in _locals - - if preprocessor_in_args and secondary_preprocessor_in_args: - raise DeprecationWarning( - "preprocessor_args and secondary_preprocessor_args arguments are " - "deprecated. Use " - f"preprocessors=[{_locals['preprocessor_args']}, " - f"{_locals['secondary_preprocessor_args']}] " - "instead.\n\nSee https://wefe.readthedocs.io/en/latest/user_guide_" - "measurement.html#word-preprocessors for more information." - ) - if preprocessor_in_args: - raise DeprecationWarning( - "preprocessor_args argument is deprecated. Use " - f"preprocessors=[{_locals['preprocessor_args']}] " - "instead.\n\nSee https://wefe.readthedocs.io/en/latest/user_guide_" - "measurement.html#word-preprocessors for more information." - ) - if secondary_preprocessor_in_args: - raise DeprecationWarning( - "secondary_preprocessor_args is deprecated. Use " - f"preprocessors=[{{}}, {_locals['secondary_preprocessor_args']}] " - "instead.\n\nSee https://wefe.readthedocs.io/en/latest/user_guide_" - "measurement.html#word-preprocessors for more information." + raise ValueError( + f"The cardinality of the attribute sets of the '{query.query_name}' " + f"query ({query.template[1]}) does not match the cardinality " + f"required by {self.metric_short_name} ({self.metric_template[1]})." ) @abstractmethod @@ -126,11 +122,73 @@ def run_query( query: Query, model: WordEmbeddingModel, lost_vocabulary_threshold: float = 0.2, - preprocessors: list[dict[str, Union[str, bool, Callable]]] = [{}], + preprocessors: list[dict[str, str | bool | Callable]] = [{}], strategy: str = "first", normalize: bool = False, warn_not_found_words: bool = False, *args: Any, **kwargs: Any, ) -> dict[str, Any]: + """Runs the metric on the given query and model. + + Parameters + ---------- + query : Query + A Query object that contains the target and attribute word sets to + be tested. + model : WordEmbeddingModel + A word embedding model. + lost_vocabulary_threshold : float, optional + Specifies the proportional limit of words that any set of the query is + allowed to lose when transforming its words into embeddings. + In the case that any set of the query loses proportionally more words + than this limit, the result values will be np.nan, by default 0.2. + preprocessors : list[dict[str, str | bool | Callable]] + A list with preprocessor options. + + A ``preprocessor`` is a dictionary that specifies what processing(s) are + performed on each word before it is looked up in the model vocabulary. + For example, the ``preprocessor`` + ``{'lowecase': True, 'strip_accents': True}`` allows you to lowercase + and remove the accent from each word before searching for them in the + model vocabulary. Note that an empty dictionary ``{}`` indicates that no + preprocessing is done. + + The possible options for a preprocessor are: + + * ``lowercase``: ``bool``. Indicates that the words are transformed to + lowercase. + * ``uppercase``: ``bool``. Indicates that the words are transformed to + uppercase. + * ``titlecase``: ``bool``. Indicates that the words are transformed to + titlecase. + * ``strip_accents``: ``bool``, ``{'ascii', 'unicode'}``: Specifies that + the accents of the words are eliminated. The stripping type can be + specified. True uses 'unicode' by default. + * ``preprocessor``: ``Callable``. It receives a function that operates + on each word. In the case of specifying a function, it overrides the + default preprocessor (i.e., the previous options stop working). + A list of preprocessor options allows you to search for several + variants of the words into the model. For example, the preprocessors + ``[{}, {"lowercase": True, "strip_accents": True}]`` + ``{}`` allows searching first for the original words in the vocabulary of + the model. In case some of them are not found, + ``{"lowercase": True, "strip_accents": True}`` is executed on these words + and then they are searched in the model vocabulary. + strategy : str, optional + The strategy indicates how it will use the preprocessed words: 'first' will + include only the first transformed word found. 'all' will include all + transformed words found, by default "first". + normalize : bool, optional + True indicates that embeddings will be normalized, by default False + warn_not_found_words : bool, optional + Specifies if the function will warn (in the logger) + the words that were not found in the model's vocabulary, by default False. + + Returns + ------- + dict[str, Any] + A dictionary containing the results of the metric. + + """ raise NotImplementedError()