Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/user_guide/measurement_user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ names with respect to pleasant and unpleasant attributes.

ethnicity_query = Query(
[word_sets["european_american_names_5"], word_sets["african_american_names_5"]],
[word_sets["pleasant_5"], word_sets["unpleasant_5"]],
[word_sets["pleasant_5"], word_sets["unpleasant_5a"]],
["European american names", "African american names"],
["Pleasant", "Unpleasant"],
)
Expand Down Expand Up @@ -1265,7 +1265,7 @@ Ethnicity Bias Model Ranking
# define the queries
ethnicity_query_1 = Query(
[word_sets["european_american_names_5"], word_sets["african_american_names_5"]],
[word_sets["pleasant_5"], word_sets["unpleasant_5"]],
[word_sets["pleasant_5"], word_sets["unpleasant_5a"]],
["European Names", "African Names"],
["Pleasant", "Unpleasant"],
)
Expand Down
4 changes: 2 additions & 2 deletions docs/user_guide/mitigation_user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ Next, we measure the gender bias exposed by query 2 (Male Names and Female Names

gender_query_2 = Query(
[weat_wordset["male_names"], weat_wordset["female_names"]],
[weat_wordset["pleasant_5"], weat_wordset["unpleasant_5"]],
[weat_wordset["pleasant_5"], weat_wordset["unpleasant_5a"]],
["Male Names", "Female Names"],
["Pleasant", "Unpleasant"],
)
Expand Down Expand Up @@ -433,7 +433,7 @@ equalized.

gender_query_2 = Query(
[weat_wordset["male_names"], weat_wordset["female_names"]],
[weat_wordset["pleasant_5"], weat_wordset["unpleasant_5"]],
[weat_wordset["pleasant_5"], weat_wordset["unpleasant_5a"]],
["Male Names", "Female Names"],
["Pleasant", "Unpleasant"],
)
Expand Down
10 changes: 5 additions & 5 deletions tests/debias/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def gender_query_1(weat_wordsets: dict[str, list[str]]) -> Query:
"""
query = Query(
[weat_wordsets["male_names"], weat_wordsets["female_names"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5a"]],
["Male Names", "Female Names"],
["Pleasant", "Unpleasant"],
)
Expand Down Expand Up @@ -202,9 +202,9 @@ def ethnicity_query_1(weat_wordsets: dict[str, list[str]]) -> Query:
weat_wordsets["european_american_names_5"],
weat_wordsets["african_american_names_5"],
],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5a"]],
["european_american_names_5", "african_american_names_5"],
["pleasant_5", "unpleasant_5"],
["pleasant_5", "unpleasant_5a"],
)
return query

Expand All @@ -226,8 +226,8 @@ def control_query_1(weat_wordsets: dict[str, list[str]]) -> Query:
"""
query = Query(
[weat_wordsets["flowers"], weat_wordsets["insects"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5a"]],
["flowers", "insects"],
["pleasant_5", "unpleasant_5"],
["pleasant_5", "unpleasant_5a"],
)
return query
2 changes: 1 addition & 1 deletion tests/debias/test_double_hard_debias.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_double_hard_debias_checks(
# )

# targets = weat_wordsets["male_names"] + weat_wordsets["female_names"]
# attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5"]
# attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5a"]
# ignore = targets + attributes

# gender_debiased_w2v = dhd.fit(
Expand Down
2 changes: 1 addition & 1 deletion tests/debias/test_half_sibling_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_half_sibling_regression_ignore_param(
)

targets = weat_wordsets["male_names"] + weat_wordsets["female_names"]
attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5"]
attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5a"]
ignore = targets + attributes

gender_debiased_w2v = hsr.fit(model, definitional_words=gender_specific).transform(
Expand Down
2 changes: 1 addition & 1 deletion tests/debias/test_hard_debias.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_hard_debias_ignore_param(
# this implies that neither of these words should be subjected to debias and
# therefore, both queries when executed with weat should return the same score.
targets = weat_wordsets["male_names"] + weat_wordsets["female_names"]
attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5"]
attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5a"]
ignore = targets + attributes

gender_debiased_w2v = hd.fit(
Expand Down
2 changes: 1 addition & 1 deletion tests/debias/test_multiclass_hard_debias.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_multiclass_hard_debias_ignore_param(
# this implies that neither of these words should be subjected to debias and
# therefore, both queries when executed with weat should return the same score.
targets = weat_wordsets["male_names"] + weat_wordsets["female_names"]
attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5"]
attributes = weat_wordsets["pleasant_5"] + weat_wordsets["unpleasant_5a"]
ignore = targets + attributes

gender_debiased_w2v = mhd.fit(
Expand Down
10 changes: 5 additions & 5 deletions tests/metrics/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def query_2t2a_1(weat_wordsets: dict[str, list[str]]) -> Query:
"""
query = Query(
[weat_wordsets["flowers"], weat_wordsets["insects"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5a"]],
["Flowers", "Insects"],
["Pleasant", "Unpleasant"],
)
Expand All @@ -87,7 +87,7 @@ def query_3t2a_1(weat_wordsets: dict[str, list[str]]) -> Query:
weat_wordsets["insects"],
weat_wordsets["instruments"],
],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5a"]],
["Flowers", "Weapons", "Instruments"],
["Pleasant", "Unpleasant"],
)
Expand All @@ -104,7 +104,7 @@ def query_4t2a_1(weat_wordsets: dict[str, list[str]]) -> Query:
weat_wordsets["instruments"],
weat_wordsets["weapons"],
],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5a"]],
["Flowers", "Insects", "Instruments", "Weapons"],
["Pleasant", "Unpleasant"],
)
Expand All @@ -119,7 +119,7 @@ def query_1t4_1(weat_wordsets: dict[str, list[str]]) -> Query:
[
weat_wordsets["pleasant_5"],
weat_wordsets["pleasant_9"],
weat_wordsets["unpleasant_5"],
weat_wordsets["unpleasant_5a"],
weat_wordsets["unpleasant_9"],
],
["Flowers"],
Expand All @@ -144,7 +144,7 @@ def query_2t1a_lost_vocab_1(weat_wordsets: dict[str, list[str]]) -> Query:
def query_2t2a_lost_vocab_1(weat_wordsets: dict[str, list[str]]) -> Query:
query = Query(
[["bla", "asd"], weat_wordsets["insects"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5"]],
[weat_wordsets["pleasant_5"], weat_wordsets["unpleasant_5a"]],
["Flowers", "Insects"],
["Pleasant", "Unpleasant"],
)
Expand Down
192 changes: 191 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import socket
import urllib.error

import pytest

from wefe.datasets.datasets import (
_retry_request,
fetch_debias_multiclass,
fetch_debiaswe,
fetch_eds,
Expand Down Expand Up @@ -132,11 +138,12 @@ def test_load_weat() -> None:
"flowers",
"insects",
"pleasant_5",
"unpleasant_5",
"unpleasant_5a",
"instruments",
"weapons",
"european_american_names_5",
"african_american_names_5",
"unpleasant_5b",
"european_american_names_7",
"african_american_names_7",
"pleasant_9",
Expand Down Expand Up @@ -180,3 +187,186 @@ def test_load_gn_glove() -> None:
for word in set_:
assert isinstance(word, str)
assert len(word) > 0


# Tests for retry functionality
class TestRetryRequest:
    """Test cases for the ``_retry_request`` helper.

    Covers the full retry matrix: immediate success, retryable failures
    (HTTP 429, ``socket.timeout``, ``OSError``, generic exceptions),
    non-retryable failures (HTTP 404, bare ``URLError``), and exhaustion
    of the retry budget.
    """

    @staticmethod
    def _make_http_error(code: int, msg: str) -> urllib.error.HTTPError:
        """Build an ``HTTPError`` with the given status, for use as a mock side effect."""
        from email.message import EmailMessage

        return urllib.error.HTTPError(
            url="http://test.com",
            code=code,
            msg=msg,
            hdrs=EmailMessage(),
            fp=None,
        )

    @staticmethod
    def _patch_timing(monkeypatch):
        """Replace ``time.sleep`` and ``logging.warning`` with mocks.

        Returns the ``(mock_sleep, mock_warning)`` pair so tests can assert
        on backoff delays and warning counts without actually sleeping.
        """
        from unittest.mock import Mock

        mock_sleep = Mock()
        mock_warning = Mock()
        monkeypatch.setattr("time.sleep", mock_sleep)
        monkeypatch.setattr("logging.warning", mock_warning)
        return mock_sleep, mock_warning

    def test_retry_request_success_on_first_attempt(self):
        """The wrapped function's result is returned untouched on immediate success."""
        from unittest.mock import Mock

        mock_func = Mock(return_value="success")

        result = _retry_request(mock_func, "arg1", "arg2", kwarg1="value1")

        assert result == "success"
        # Positional and keyword arguments must be forwarded verbatim.
        mock_func.assert_called_once_with("arg1", "arg2", kwarg1="value1")

    def test_retry_request_rate_limit_error(self, monkeypatch):
        """HTTP 429 (rate limit) responses are retried with exponential backoff."""
        from unittest.mock import Mock

        mock_sleep, mock_warning = self._patch_timing(monkeypatch)

        http_error = self._make_http_error(429, "Too Many Requests")
        # First two calls fail with 429, third succeeds.
        mock_func = Mock(side_effect=[http_error, http_error, "success"])

        result = _retry_request(mock_func, n_retries=3)

        assert result == "success"
        assert mock_func.call_count == 3
        assert mock_sleep.call_count == 2
        assert mock_warning.call_count == 2

        # Check exponential backoff sleep times: 2**0 then 2**1 seconds.
        mock_sleep.assert_any_call(1)
        mock_sleep.assert_any_call(2)

    def test_retry_request_timeout_error(self, monkeypatch):
        """``socket.timeout`` failures are retried."""
        from unittest.mock import Mock

        mock_sleep, mock_warning = self._patch_timing(monkeypatch)

        # First call fails with a timeout, second succeeds.
        mock_func = Mock(side_effect=[socket.timeout("Connection timeout"), "success"])

        result = _retry_request(mock_func, n_retries=2)

        assert result == "success"
        assert mock_func.call_count == 2
        mock_sleep.assert_called_once_with(1)  # 2**0 = 1
        mock_warning.assert_called_once()

    def test_retry_request_timeout_error_os_error(self, monkeypatch):
        """``OSError`` (e.g. network timeout) failures are retried."""
        from unittest.mock import Mock

        mock_sleep, mock_warning = self._patch_timing(monkeypatch)

        # First call fails with OSError, second succeeds.
        mock_func = Mock(side_effect=[OSError("Network timeout"), "success"])

        result = _retry_request(mock_func, n_retries=2)

        assert result == "success"
        assert mock_func.call_count == 2
        mock_sleep.assert_called_once_with(1)  # 2**0 = 1
        mock_warning.assert_called_once()

    def test_retry_request_generic_exception(self, monkeypatch):
        """Generic exceptions are retried (with a fixed 1-second delay)."""
        from unittest.mock import Mock

        mock_sleep, mock_warning = self._patch_timing(monkeypatch)

        # First call fails with a generic exception, second succeeds.
        mock_func = Mock(side_effect=[ValueError("Generic error"), "success"])

        result = _retry_request(mock_func, n_retries=2)

        assert result == "success"
        assert mock_func.call_count == 2
        mock_sleep.assert_called_once_with(1)  # Fixed 1-second delay
        mock_warning.assert_called_once()

    def test_retry_request_non_retryable_http_error(self):
        """HTTP errors other than 429 (e.g. 404) propagate without retrying."""
        from unittest.mock import Mock

        # 404 Not Found should not be retried.
        mock_func = Mock(side_effect=self._make_http_error(404, "Not Found"))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            _retry_request(mock_func, n_retries=3)

        assert exc_info.value.code == 404
        mock_func.assert_called_once()  # Should only be called once

    def test_retry_request_exhaust_retries(self, monkeypatch):
        """The final exception is raised once all retries are used up."""
        from unittest.mock import Mock

        mock_sleep, mock_warning = self._patch_timing(monkeypatch)

        # Always fail with a rate-limit error so the retry budget is consumed.
        mock_func = Mock(side_effect=self._make_http_error(429, "Too Many Requests"))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            _retry_request(mock_func, n_retries=2)

        assert exc_info.value.code == 429
        assert mock_func.call_count == 3  # Initial call + 2 retries
        assert mock_sleep.call_count == 2
        assert mock_warning.call_count == 2

    def test_retry_request_url_error(self):
        """A bare ``URLError`` (no HTTP status code) is not retried."""
        from unittest.mock import Mock

        mock_func = Mock(side_effect=urllib.error.URLError("Connection failed"))

        with pytest.raises(urllib.error.URLError):
            _retry_request(mock_func, n_retries=3)

        mock_func.assert_called_once()  # Should only be called once
Loading