Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions aif_gen/validation/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

def count_validation(
dataset: Dataset, remove_stop_words: bool = False
) -> List[Dict[str, int]]:
r"""Count the number of 'unique' samples in the dataset.
) -> List[Dict[str, int | float]]:
r"""Count the number of 'unique' samples and average phrase length in the dataset.

Args:
dataset (Union[ContinualAlignmentDataset, AlignmentDataset]): The dataset to validate.
Expand All @@ -18,11 +18,14 @@ def count_validation(
Returns:
List[Dict[str, int | float]]: For every AlignmentDataset, returns a dictionary with the following entries:

'samples' -> int: The total number of samples in the AlignmentDataset.
'unique_samples' -> int: The number of unique samples in the AlignmentDataset.
'unique_prompts' -> int: The number of unique prompts in the AlignmentDataset.
'unique_chosen' -> int: The number of unique chosen responses in the AlignmentDataset.
'unique_rejected' -> int: The number of unique rejected responses in the AlignmentDataset.
'sample' -> int: The total number of samples in the AlignmentDataset.
'unique_samples' -> int: The number of unique samples in the AlignmentDataset.
'unique_prompts' -> int: The number of unique prompts in the AlignmentDataset.
'unique_chosen' -> int: The number of unique chosen responses in the AlignmentDataset.
'unique_rejected' -> int: The number of unique rejected responses in the AlignmentDataset.
'avg_prompt_length' -> float: The average length of prompts in the AlignmentDataset.
'avg_chosen_length' -> float: The average length of chosen responses in the AlignmentDataset.
'avg_rejected_length' -> float: The average length of rejected responses in the AlignmentDataset.

Note:
If the input dataset is an AlignmentDataset (non-continual), this function
Expand All @@ -43,17 +46,24 @@ def count_validation(

def _count_validation(
dataset: AlignmentDataset, remove_stop_words: bool
) -> Dict[str, int]:
) -> Dict[str, int | float]:
samples, prompts, chosen, rejected = set(), set(), set(), set()
for sample in dataset.samples:
samples.add(rsw(str(sample)) if remove_stop_words else str(sample))
prompts.add(rsw(sample.prompt) if remove_stop_words else sample.prompt)
chosen.add(rsw(sample.chosen) if remove_stop_words else sample.chosen)
rejected.add(rsw(sample.rejected) if remove_stop_words else sample.rejected)

def _avg_word_count(strs: List[str]) -> float:
return sum(len(s.split()) for s in strs) / len(strs) if len(strs) else 0

return {
'sample': dataset.num_samples,
'unique_samples': len(samples),
'unique_prompts': len(prompts),
'unique_chosen': len(chosen),
'unique_rejected': len(rejected),
'avg_prompt_length': _avg_word_count([x.prompt for x in dataset.samples]),
'avg_chosen_length': _avg_word_count([x.chosen for x in dataset.samples]),
'avg_rejected_length': _avg_word_count([x.rejected for x in dataset.samples]),
}
27 changes: 27 additions & 0 deletions test/test_validation/test_count_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def test_count_validation_all_unique():
'unique_prompts': 3,
'unique_chosen': 3,
'unique_rejected': 3,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
}
]
assert count_validation(dataset) == expected_counts
Expand Down Expand Up @@ -59,6 +62,9 @@ def test_count_validation_all_same_prompts():
'unique_prompts': 1,
'unique_chosen': 3,
'unique_rejected': 3,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
}
]
assert count_validation(dataset) == expected_counts
Expand Down Expand Up @@ -87,6 +93,9 @@ def test_count_validation_all_same_responses():
'unique_prompts': 3,
'unique_chosen': 1,
'unique_rejected': 1,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
}
]
assert count_validation(dataset) == expected_counts
Expand Down Expand Up @@ -115,6 +124,9 @@ def test_count_validation_all_same_everything():
'unique_prompts': 1,
'unique_chosen': 1,
'unique_rejected': 1,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
}
]
assert count_validation(dataset) == expected_counts
Expand Down Expand Up @@ -187,27 +199,39 @@ def test_count_countinual_dataset():
'unique_prompts': 3,
'unique_chosen': 3,
'unique_rejected': 3,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
},
{
'sample': 3,
'unique_samples': 3,
'unique_prompts': 1,
'unique_chosen': 3,
'unique_rejected': 3,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
},
{
'sample': 3,
'unique_samples': 3,
'unique_prompts': 3,
'unique_chosen': 1,
'unique_rejected': 1,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
},
{
'sample': 3,
'unique_samples': 1,
'unique_prompts': 1,
'unique_chosen': 1,
'unique_rejected': 1,
'avg_prompt_length': 4.0,
'avg_chosen_length': 4.0,
'avg_rejected_length': 4.0,
},
]
assert count_validation(dataset) == expected_counts
Expand Down Expand Up @@ -236,6 +260,9 @@ def test_count_validation_stop_words_removed():
'unique_prompts': 1,
'unique_chosen': 1,
'unique_rejected': 1,
'avg_prompt_length': 5.0,
'avg_chosen_length': 5.0,
'avg_rejected_length': 5.0,
}
]
assert count_validation(dataset, remove_stop_words=True) == expected_counts
Loading