Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 65 additions & 58 deletions tests/metrics/test_base_metric.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Unit tests for the BaseMetric class in the wefe.metrics.base_metric module."""

import pytest

from wefe.metrics.base_metric import BaseMetric
Expand All @@ -10,6 +12,21 @@ def test_base_metric_input_checking(
query_2t2a_1: Query,
query_3t2a_1: Query,
) -> None:
"""Test input validation for the BaseMetric class.
This test verifies that the `_check_input` method of `BaseMetric` correctly
raises exceptions when provided with invalid inputs, such as incorrect types
for `query` and `model`, or mismatched cardinalities for target and attribute sets.

Parameters
----------
model : WordEmbeddingModel
A word embedding model instance used for metric evaluation.
query_2t2a_1 : Query
A query instance with 2 target sets and 2 attribute sets.
query_3t2a_1 : Query
A query instance with 3 target sets and 2 attribute sets.

"""
# Create and configure base metric testing.
# disable abstract methods.

Expand All @@ -21,84 +38,74 @@ def test_base_metric_input_checking(
base_metric.metric_short_name = "TM"

with pytest.raises(TypeError, match="query should be a Query instance, got*"):
base_metric._check_input(None, model, {})

with pytest.raises(
TypeError, match="word_embedding should be a WordEmbeddingModel instance, got*"
):
base_metric._check_input(query_2t2a_1, None, {})

with pytest.raises(
Exception,
match="The cardinality of the set of target words of the 'Flowers, Weapons and "
"Instruments wrt Pleasant and Unpleasant' query does not match with the "
"cardinality required by TM. Provided query: 3, metric: 2",
):
base_metric._check_input(query_3t2a_1, model, {})

with pytest.raises(
Exception,
match=(
"The cardinality of the set of attribute words of the 'Flowers and Insects "
"wrt Pleasant and Unpleasant' query does not match with the cardinality "
"required by TM. Provided query: 2, metric: 3"
),
):
base_metric._check_input(query_2t2a_1, model, {})


def test_validate_old_preprocessor_args_inputs(
model: WordEmbeddingModel,
query_2t2a_1: Query,
) -> None:
# instance test metric
BaseMetric.__abstractmethods__ = frozenset()
base_metric = BaseMetric()
base_metric.metric_template = (2, 2)
base_metric.metric_name = "Test Metric"
base_metric.metric_short_name = "TM"
base_metric._check_input(
query=None,
model=model,
lost_vocabulary_threshold=0.2,
warn_not_found_words=True,
)

with pytest.raises(
DeprecationWarning,
match=(
r"preprocessor_args argument is deprecated. "
r"Use preprocessors=\[\{'uppercase': True\}\] instead.*."
),
TypeError,
match="model should be a WordEmbeddingModel instance, got: <class 'NoneType'>.",
):
base_metric._check_input(
query_2t2a_1, model, {"preprocessor_args": {"uppercase": True}}
query=query_2t2a_1,
model=None,
lost_vocabulary_threshold=0.2,
warn_not_found_words=True,
)

with pytest.raises(
DeprecationWarning,
Exception,
match=(
r"secondary_preprocessor_args is deprecated. "
r"Use preprocessors=\[\{\}, \{'uppercase': True\}\] instead.*."
r"The cardinality of the target sets of the 'Flowers, Weapons and "
r"Instruments wrt Pleasant and Unpleasant' query \(3\) does not match "
r"the cardinality required by TM \(2\)."
),
):
base_metric._check_input(
query_2t2a_1, model, {"secondary_preprocessor_args": {"uppercase": True}}
query=query_3t2a_1,
model=model,
lost_vocabulary_threshold=0.2,
warn_not_found_words=True,
)

with pytest.raises(
DeprecationWarning,
Exception,
match=(
r"preprocessor_args and secondary_preprocessor_args arguments are "
r"deprecated. Use preprocessors=\[\{'uppercase': True\}, \{'uppercase': "
r"True\}\] instead."
r"The cardinality of the attribute sets of the 'Flowers and Insects wrt "
r"Pleasant and Unpleasant' query \(2\) does not match the cardinality "
r"required by TM \(3\)."
),
):
base_metric._check_input(
query_2t2a_1,
model,
{
"preprocessor_args": {"uppercase": True},
"secondary_preprocessor_args": {"uppercase": True},
},
query=query_2t2a_1,
model=model,
lost_vocabulary_threshold=0.2,
warn_not_found_words=True,
)


def test_run_query(model: WordEmbeddingModel, query_2t2a_1: Query) -> None:
def test_run_query(
model: WordEmbeddingModel,
query_2t2a_1: Query,
) -> None:
"""Test that the `run_query` method of `BaseMetric` raises a NotImplementedError.

Parameters
----------
model : WordEmbeddingModel
The word embedding model to be used in the query.
query_2t2a_1 : Query
The query object to be passed to the metric.

Raises
------
NotImplementedError
If the `run_query` method is not implemented in `BaseMetric`.

"""
# disable abstract methods.
BaseMetric.__abstractmethods__ = frozenset()
base_metric = BaseMetric()
Expand All @@ -110,4 +117,4 @@ def test_run_query(model: WordEmbeddingModel, query_2t2a_1: Query) -> None:
with pytest.raises(
NotImplementedError,
):
base_metric.run_query(query_2t2a_1, model)
base_metric.run_query(query=query_2t2a_1, model=model)
7 changes: 6 additions & 1 deletion wefe/metrics/ECT.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,12 @@ def run_query(

"""
# check the types of the provided arguments (only the defaults).
self._check_input(query, model, locals())
self._check_input(
query=query,
model=model,
lost_vocabulary_threshold=lost_vocabulary_threshold,
warn_not_found_words=warn_not_found_words,
)

# transform query word sets into embeddings
embeddings = get_embeddings_from_query(
Expand Down
7 changes: 6 additions & 1 deletion wefe/metrics/MAC.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,12 @@ def run_query(

""" # noqa: E501
# check the types of the provided arguments (only the defaults).
self._check_input(query, model, locals())
self._check_input(
query=query,
model=model,
lost_vocabulary_threshold=lost_vocabulary_threshold,
warn_not_found_words=warn_not_found_words,
)

# transform query word sets into embeddings
embeddings = get_embeddings_from_query(
Expand Down
7 changes: 6 additions & 1 deletion wefe/metrics/RIPA.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,12 @@ def run_query(

"""
# check the types of the provided arguments (only the defaults).
self._check_input(query, model, locals())
self._check_input(
query=query,
model=model,
lost_vocabulary_threshold=lost_vocabulary_threshold,
warn_not_found_words=warn_not_found_words,
)

# transform query word sets into embeddings
embeddings = get_embeddings_from_query(
Expand Down
7 changes: 6 additions & 1 deletion wefe/metrics/RND.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,12 @@ def run_query(

"""
# check the types of the provided arguments (only the defaults).
self._check_input(query, model, locals())
self._check_input(
query=query,
model=model,
lost_vocabulary_threshold=lost_vocabulary_threshold,
warn_not_found_words=warn_not_found_words,
)

# transform query word sets into embeddings
embeddings = get_embeddings_from_query(
Expand Down
7 changes: 6 additions & 1 deletion wefe/metrics/RNSB.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,12 @@ def run_query(

"""
# check the types of the provided arguments (only the defaults).
self._check_input(query, model, locals())
self._check_input(
query=query,
model=model,
lost_vocabulary_threshold=lost_vocabulary_threshold,
warn_not_found_words=warn_not_found_words,
)

if n_iterations > 1 and random_state is not None:
raise ValueError(
Expand Down
7 changes: 6 additions & 1 deletion wefe/metrics/WEAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,12 @@ def run_query(

"""
# check the types of the provided arguments (only the defaults).
self._check_input(query, model, locals())
self._check_input(
query=query,
model=model,
lost_vocabulary_threshold=lost_vocabulary_threshold,
warn_not_found_words=warn_not_found_words,
)

# transform query word sets into embeddings
embeddings = get_embeddings_from_query(
Expand Down
Loading