From 3254c999016c14ac1a0a566cc3211644048ba927 Mon Sep 17 00:00:00 2001 From: Jacob-Chmura Date: Wed, 14 May 2025 09:18:31 -0400 Subject: [PATCH] Add phrase lengths to count validation --- aif_gen/validation/counts.py | 26 ++++++++++++------ test/test_validation/test_count_validation.py | 27 +++++++++++++++++++ 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/aif_gen/validation/counts.py b/aif_gen/validation/counts.py index a5178093..34b256ea 100644 --- a/aif_gen/validation/counts.py +++ b/aif_gen/validation/counts.py @@ -8,8 +8,8 @@ def count_validation( dataset: Dataset, remove_stop_words: bool = False -) -> List[Dict[str, int]]: - r"""Count the number of 'unique' samples in the dataset. +) -> List[Dict[str, int | float]]: + r"""Count the number of 'unique' samples and average phrase length in the dataset. Args: dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to validate. @@ -18,11 +18,14 @@ Returns: List[Dict[str, int]]: For every AligmentDataset, returns a dictionary with the following entries: - 'samples' -> int: The total number of samples in the AlignmentDataset. - 'unique_samples' -> int: The number of unique samples in the AlignmentDataset. - 'unique_prompts' -> int: The number of unique prompts in the AlignmentDataset. - 'unique_chosen' -> int: The number of unique chosen responses in the AlignmentDataset. - 'unique_rejected' -> int: The number of unique rejected responses in the AlignmentDataset. + 'sample' -> int: The total number of samples in the AlignmentDataset. + 'unique_samples' -> int: The number of unique samples in the AlignmentDataset. + 'unique_prompts' -> int: The number of unique prompts in the AlignmentDataset. + 'unique_chosen' -> int: The number of unique chosen responses in the AlignmentDataset. + 'unique_rejected' -> int: The number of unique rejected responses in the AlignmentDataset. + 'avg_prompt_length' -> float: The average length of prompts in the AlignmentDataset. 
+ 'avg_chosen_length' -> float: The average length of chosen responses in the AlignmentDataset. + 'avg_rejected_length' -> float: The average length of rejected responses in the AlignmentDataset. Note: If the input dataset is an AlignmentDataset (non-continual), this function @@ -43,17 +46,24 @@ def count_validation( def _count_validation( dataset: AlignmentDataset, remove_stop_words: bool -) -> Dict[str, int]: +) -> Dict[str, int | float]: samples, prompts, chosen, rejected = set(), set(), set(), set() for sample in dataset.samples: samples.add(rsw(str(sample)) if remove_stop_words else str(sample)) prompts.add(rsw(sample.prompt) if remove_stop_words else sample.prompt) chosen.add(rsw(sample.chosen) if remove_stop_words else sample.chosen) rejected.add(rsw(sample.rejected) if remove_stop_words else sample.rejected) + + def _avg_word_count(strs: List[str]) -> float: + return sum(len(s.split()) for s in strs) / len(strs) if len(strs) else 0 + return { 'sample': dataset.num_samples, 'unique_samples': len(samples), 'unique_prompts': len(prompts), 'unique_chosen': len(chosen), 'unique_rejected': len(rejected), + 'avg_prompt_length': _avg_word_count([x.prompt for x in dataset.samples]), + 'avg_chosen_length': _avg_word_count([x.chosen for x in dataset.samples]), + 'avg_rejected_length': _avg_word_count([x.rejected for x in dataset.samples]), } diff --git a/test/test_validation/test_count_validation.py b/test/test_validation/test_count_validation.py index 0e65a843..992dfaa0 100644 --- a/test/test_validation/test_count_validation.py +++ b/test/test_validation/test_count_validation.py @@ -31,6 +31,9 @@ def test_count_validation_all_unique(): 'unique_prompts': 3, 'unique_chosen': 3, 'unique_rejected': 3, + 'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, } ] assert count_validation(dataset) == expected_counts @@ -59,6 +62,9 @@ def test_count_validation_all_same_prompts(): 'unique_prompts': 1, 'unique_chosen': 3, 'unique_rejected': 3, + 
'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, } ] assert count_validation(dataset) == expected_counts @@ -87,6 +93,9 @@ def test_count_validation_all_same_responses(): 'unique_prompts': 3, 'unique_chosen': 1, 'unique_rejected': 1, + 'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, } ] assert count_validation(dataset) == expected_counts @@ -115,6 +124,9 @@ def test_count_validation_all_same_everything(): 'unique_prompts': 1, 'unique_chosen': 1, 'unique_rejected': 1, + 'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, } ] assert count_validation(dataset) == expected_counts @@ -187,6 +199,9 @@ def test_count_countinual_dataset(): 'unique_prompts': 3, 'unique_chosen': 3, 'unique_rejected': 3, + 'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, }, { 'sample': 3, @@ -194,6 +209,9 @@ def test_count_countinual_dataset(): 'unique_prompts': 1, 'unique_chosen': 3, 'unique_rejected': 3, + 'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, }, { 'sample': 3, @@ -201,6 +219,9 @@ def test_count_countinual_dataset(): 'unique_prompts': 3, 'unique_chosen': 1, 'unique_rejected': 1, + 'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, }, { 'sample': 3, @@ -208,6 +229,9 @@ def test_count_countinual_dataset(): 'unique_prompts': 1, 'unique_chosen': 1, 'unique_rejected': 1, + 'avg_prompt_length': 4.0, + 'avg_chosen_length': 4.0, + 'avg_rejected_length': 4.0, }, ] assert count_validation(dataset) == expected_counts @@ -236,6 +260,9 @@ def test_count_validation_stop_words_removed(): 'unique_prompts': 1, 'unique_chosen': 1, 'unique_rejected': 1, + 'avg_prompt_length': 5.0, + 'avg_chosen_length': 5.0, + 'avg_rejected_length': 5.0, } ] assert count_validation(dataset, remove_stop_words=True) == expected_counts