From 6ca5e0efc6ff6b5b702977dc7bf14d6e36364236 Mon Sep 17 00:00:00 2001 From: mali-git Date: Thu, 12 Dec 2024 11:15:51 +0100 Subject: [PATCH 1/5] feat: Add data model --- src/ml_filter/data_models.py | 67 ++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 src/ml_filter/data_models.py diff --git a/src/ml_filter/data_models.py b/src/ml_filter/data_models.py new file mode 100644 index 00000000..4fb364b9 --- /dev/null +++ b/src/ml_filter/data_models.py @@ -0,0 +1,67 @@ +from enum import Enum +from typing import Dict, Union + +from pydantic import BaseModel, Field + + +# Define DecodingStrategy Enum +class DecodingStrategy(str, Enum): + greedy = "greedy" + beam_search = "beam_search" + top_k = "top_k" + top_p = "top_p" + + +# Base class for decoding strategy parameters +class DecodingParameters(BaseModel): + strategy: DecodingStrategy + + +# Decoding strategy parameter classes +class GreedyParameters(DecodingParameters): + strategy: DecodingStrategy = Field(default=DecodingStrategy.greedy) + + +class BeamSearchParameters(DecodingParameters): + strategy: DecodingStrategy = Field(default=DecodingStrategy.beam_search) + num_beams: int + early_stopping: bool + + +class TopKParameters(DecodingParameters): + strategy: DecodingStrategy = Field(default=DecodingStrategy.top_k) + top_k: int + temperature: float + + +class TopPParameters(DecodingParameters): + strategy: DecodingStrategy = Field(default=DecodingStrategy.top_p) + top_p: float + temperature: float + + +# General Information about a document +class DocumentInfo(BaseModel): + document_id: str + prompt: str + prompt_lang: str + raw_data_path: str + model: str + decoding_parameters: Union[GreedyParameters, BeamSearchParameters, TopKParameters, TopPParameters] + + +# Statistical correlations for performance evaluation +class CorrelationMetrics(BaseModel): + correlation: Dict[str, Dict[str, float]] # Correlation per ground truth approach + + +# T-Test and p-value results for performance evaluation +class TTestResults(BaseModel): + t_test_p_values: Dict[str, float] # p-values for each ground truth approach + + +# Complete statistical report combining various metrics +class StatisticReport(BaseModel): + document_info: DocumentInfo + correlation_metrics: CorrelationMetrics + t_test_results: TTestResults From ae7d6270ba641a46aeb7e6a77b33484d1eda5ab7 Mon Sep 17 00:00:00 2001 From: mali-git Date: Thu, 12 Dec 2024 11:23:17 +0100 Subject: [PATCH 2/5] docs: add docstrings --- src/ml_filter/data_models.py | 39 ++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/ml_filter/data_models.py b/src/ml_filter/data_models.py index 4fb364b9..8b2d3240 100644 --- a/src/ml_filter/data_models.py +++ b/src/ml_filter/data_models.py @@ -6,42 +6,56 @@ # Define DecodingStrategy Enum class DecodingStrategy(str, Enum): - greedy = "greedy" - beam_search = "beam_search" - top_k = "top_k" - top_p = "top_p" + """Decoding strategies for text generation models""" + + GREEDY = "greedy" + BEAM_SEARCH = "beam_search" + TOP_K = "top_k" + TOP_P = "top_p" # Base class for decoding strategy parameters class DecodingParameters(BaseModel): + """Decoding strategy parameters""" + strategy: DecodingStrategy # Decoding strategy parameter classes class GreedyParameters(DecodingParameters): - strategy: DecodingStrategy = Field(default=DecodingStrategy.greedy) + """Greedy decoding strategy parameters""" + + strategy: DecodingStrategy = Field(default=DecodingStrategy.GREEDY) class BeamSearchParameters(DecodingParameters): - strategy: DecodingStrategy = Field(default=DecodingStrategy.beam_search) + """Beam search decoding strategy parameters""" + + strategy: DecodingStrategy = Field(default=DecodingStrategy.BEAM_SEARCH) num_beams: int early_stopping: bool class TopKParameters(DecodingParameters): - strategy: DecodingStrategy = Field(default=DecodingStrategy.top_k) + """Top-K decoding strategy parameters""" + + strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_K) top_k: int temperature: float class TopPParameters(DecodingParameters): - strategy: DecodingStrategy = Field(default=DecodingStrategy.top_p) + """Top-P decoding strategy parameters""" + + strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_P) top_p: float temperature: float # General Information about a document class DocumentInfo(BaseModel): + """General information about a document""" + document_id: str prompt: str prompt_lang: str @@ -50,18 +64,21 @@ class DocumentInfo(BaseModel): decoding_parameters: Union[GreedyParameters, BeamSearchParameters, TopKParameters, TopPParameters] -# Statistical correlations for performance evaluation class CorrelationMetrics(BaseModel): + """Correlation metrics for performance evaluation""" + correlation: Dict[str, Dict[str, float]] # Correlation per ground truth approach -# T-Test and p-value results for performance evaluation class TTestResults(BaseModel): + """T-Test results for performance evaluation""" + t_test_p_values: Dict[str, float] # p-values for each ground truth approach -# Complete statistical report combining various metrics class StatisticReport(BaseModel): + """Complete statistical report combining various metrics""" + document_info: DocumentInfo correlation_metrics: CorrelationMetrics t_test_results: TTestResults From e6c3063fee208fababfcae7316fcc43d09b5fdde Mon Sep 17 00:00:00 2001 From: mali-git Date: Thu, 12 Dec 2024 11:46:55 +0100 Subject: [PATCH 3/5] refactor: add constraints --- src/ml_filter/data_models.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ml_filter/data_models.py b/src/ml_filter/data_models.py index 8b2d3240..6cfa8f05 100644 --- a/src/ml_filter/data_models.py +++ b/src/ml_filter/data_models.py @@ -32,7 +32,7 @@ class BeamSearchParameters(DecodingParameters): """Beam search decoding strategy parameters""" strategy: DecodingStrategy = Field(default=DecodingStrategy.BEAM_SEARCH) - num_beams: int + num_beams: int = Field(..., gt=0, description="Number of beams must be greater than 0.") early_stopping: bool @@ -40,16 +40,18 @@ class TopKParameters(DecodingParameters): """Top-K decoding strategy parameters""" strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_K) - top_k: int - temperature: float + top_k: int = Field(..., gt=0, description="Number of top candidates to consider. Must be greater than 0.") + temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.") class TopPParameters(DecodingParameters): """Top-P decoding strategy parameters""" strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_P) - top_p: float - temperature: float + top_p: float = Field( + ..., gt=0, le=1, description="Cumulative probability for nucleus sampling. Must be in the range (0, 1]." + ) + temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.") # General Information about a document From 73065faf184fbaa25198b4d58eb5515a615cd09a Mon Sep 17 00:00:00 2001 From: mali-git Date: Thu, 12 Dec 2024 11:47:15 +0100 Subject: [PATCH 4/5] test: test data models --- tests/test_data_models.py | 103 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/test_data_models.py diff --git a/tests/test_data_models.py b/tests/test_data_models.py new file mode 100644 index 00000000..65a2ef45 --- /dev/null +++ b/tests/test_data_models.py @@ -0,0 +1,103 @@ +import pytest +from pydantic import ValidationError + +from ml_filter.data_models import ( + BeamSearchParameters, + CorrelationMetrics, + DecodingStrategy, + DocumentInfo, + GreedyParameters, + StatisticReport, + TopKParameters, + TopPParameters, + TTestResults, +) + + +def test_greedy_parameters(): + params = GreedyParameters() + assert params.strategy == DecodingStrategy.GREEDY + + +def test_beam_search_parameters(): + params = BeamSearchParameters(num_beams=10, early_stopping=False) + assert params.strategy == DecodingStrategy.BEAM_SEARCH + assert params.num_beams == 10 + assert not params.early_stopping + + +def test_top_k_parameters(): + params = TopKParameters(top_k=30, temperature=0.7) + assert params.strategy == DecodingStrategy.TOP_K + assert params.top_k == 30 + assert params.temperature == 0.7 + + +def test_top_p_parameters(): + params = TopPParameters(top_p=0.85, temperature=0.9) + assert params.strategy == DecodingStrategy.TOP_P + assert params.top_p == 0.85 + assert params.temperature == 0.9 + + +def test_invalid_decoding_parameters(): + with pytest.raises(ValidationError): + BeamSearchParameters(num_beams=-1, early_stopping=False) # Invalid num_beams + with pytest.raises(ValidationError): + TopKParameters(top_k=-5, temperature=0.7) # Invalid top_k + with pytest.raises(ValidationError): + TopPParameters(top_p=1.5, temperature=0.8) # Invalid top_p + + +def test_document_info_with_greedy(): + doc_info = DocumentInfo( + document_id="doc_001", + prompt="Asses the educational value of the text.", + prompt_lang="en", + raw_data_path="/path/to/raw_data.json", + model="test_model", + decoding_parameters=GreedyParameters(), + ) + assert doc_info.document_id == "doc_001" + assert doc_info.decoding_parameters.strategy == DecodingStrategy.GREEDY + + +def test_document_info_with_top_p(): + doc_info = DocumentInfo( + document_id="doc_002", + prompt="Asses, whether the text contains adult content.", + prompt_lang="en", + raw_data_path="/path/to/raw_data.json", + model="test_model", + decoding_parameters=TopPParameters(top_p=0.8, temperature=0.6), + ) + assert doc_info.document_id == "doc_002" + assert doc_info.decoding_parameters.top_p == 0.8 + assert doc_info.decoding_parameters.temperature == 0.6 + + +def test_statistic_report(): + doc_info = DocumentInfo( + document_id="doc_003", + prompt="Asses, whether the text contains chain of thoughts.", + prompt_lang="en", + raw_data_path="/path/to/raw_data.json", + model="test_model", + decoding_parameters=BeamSearchParameters(num_beams=5, early_stopping=True), + ) + correlation_metrics = CorrelationMetrics( + correlation={ + "average": {"pearson": 0.85, "spearman": 0.82}, + "min": {"pearson": 0.75, "spearman": 0.72}, + } + ) + t_test_results = TTestResults(t_test_p_values={"average": 0.03, "min": 0.05}) + report = StatisticReport( + document_info=doc_info, + correlation_metrics=correlation_metrics, + t_test_results=t_test_results, + ) + + assert report.document_info.document_id == "doc_003" + assert report.correlation_metrics.correlation["average"]["pearson"] == 0.85 + assert report.t_test_results.t_test_p_values["average"] == 0.03 From 3795efc4cacd7f14d75d520e26c96e307f673b80 Mon Sep 17 00:00:00 2001 From: mali-git Date: Thu, 12 Dec 2024 11:58:10 +0100 Subject: [PATCH 5/5] test: fix test --- tests/test_translate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_translate.py b/tests/test_translate.py index 096db068..135e4850 100644 --- a/tests/test_translate.py +++ b/tests/test_translate.py @@ -39,6 +39,8 @@ def test_translate_jsonl_to_multiple_languages( """Test the translate_jsonl_to_multiple_languages method.""" class MockTranslationClient: + name: str = "mock_translation_client" + def translate_text(self, text, source_language, target_language): return mock_translate_text(text, source_language, target_language) @@ -81,7 +83,7 @@ def supported_target_languages(self): # Verify output files for lang in target_languages: - output_file = output_folder / f"input_{lang}.jsonl" + output_file = output_folder / f"input_{lang}_{mock_client.name}.jsonl" assert output_file.exists() with open(output_file, "r", encoding="utf-8") as f: