From 1db784e3b1aad749d42f1efde57d9364099bf850 Mon Sep 17 00:00:00 2001
From: Jordan Clive
Date: Fri, 28 Apr 2023 12:51:33 +0100
Subject: [PATCH 1/2] Update metadata_utils.py

fairer without metadata. If you keep the metadata_probability the same, the
idea is to give the same context in each example in a without_metadata run.
---
 bsmetadata/metadata_utils.py | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index f3f942e7..ae8afba9 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -48,7 +48,7 @@ class BasicMetadata:
 
 
 def add_metadata_and_chunk_examples(
-    examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: MetadataConfig
+    examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: MetadataConfig, without_metadata: bool = False
 ) -> Dict[str, List]:
     """Adds metadata to the provided input examples, encodes them and groups them in chunks of size `cfg.max_seq_len`.
 
@@ -124,16 +124,27 @@ def is_metadata(idx: int) -> bool:
         for text_chunk_encoded, chunk_metadata_mask in chunks(
             max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
         ):
-            total_len = prefix_len + len(text_chunk_encoded)
-            padding_len = max_text_len - len(text_chunk_encoded)
-
-            input_ids = metadata_prefix_encoded + text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
-            attention_mask = [1] * total_len + [0] * padding_len
-            metadata_mask = [1] * prefix_len + [int(x) for x in chunk_metadata_mask] + [0] * padding_len
-
-            linearized_examples["input_ids"].append(input_ids)
-            linearized_examples["attention_mask"].append(attention_mask)
-            linearized_examples["metadata_mask"].append(metadata_mask)
+            if without_metadata:
+                total_len = len(text_chunk_encoded)
+                padding_len = cfg.max_seq_len - len(text_chunk_encoded)
+                attention_mask = [1] * total_len + [0] * padding_len
+                text_chunk_encoded = text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
+                input_ids = text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
+                metadata_mask = [0] * total_len
+                linearized_examples["input_ids"].append(input_ids)
+                linearized_examples["attention_mask"].append(attention_mask)
+                linearized_examples["metadata_mask"].append(metadata_mask)
+            else:
+                total_len = prefix_len + len(text_chunk_encoded)
+                padding_len = max_text_len - len(text_chunk_encoded)
+
+                input_ids = metadata_prefix_encoded + text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
+                attention_mask = [1] * total_len + [0] * padding_len
+                metadata_mask = [1] * prefix_len + [int(x) for x in chunk_metadata_mask] + [0] * padding_len
+
+                linearized_examples["input_ids"].append(input_ids)
+                linearized_examples["attention_mask"].append(attention_mask)
+                linearized_examples["metadata_mask"].append(metadata_mask)
 
     return linearized_examples
 

From fb6c139459c0f389d0223c2924cec1342882fa8b Mon Sep 17 00:00:00 2001
From: Jordan Clive
Date: Fri, 28 Apr 2023 14:25:57 +0100
Subject: [PATCH 2/2] Update metadata_utils.py

fix duplication.
---
 bsmetadata/metadata_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index ae8afba9..77024e10 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -128,7 +128,6 @@ def is_metadata(idx: int) -> bool:
                 total_len = len(text_chunk_encoded)
                 padding_len = cfg.max_seq_len - len(text_chunk_encoded)
                 attention_mask = [1] * total_len + [0] * padding_len
-                text_chunk_encoded = text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
                 input_ids = text_chunk_encoded + [tokenizer.eos_token_id] * padding_len
                 metadata_mask = [0] * total_len
                 linearized_examples["input_ids"].append(input_ids)
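
Usage note (not part of the patches above): a minimal sketch of how the new
without_metadata flag might be wired into a preprocessing pipeline. The
tokenizer, the default-constructed MetadataConfig, and the train.jsonl data
file are assumptions for illustration; the point is that the same cfg
(including metadata_probability) is reused, so a baseline run sees the same
examples and chunking, just without the metadata prefix.

    # Hypothetical usage sketch; setup values below are assumed, not from the patch.
    from functools import partial

    from datasets import load_dataset
    from transformers import AutoTokenizer

    from bsmetadata.metadata_utils import MetadataConfig, add_metadata_and_chunk_examples

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer
    cfg = MetadataConfig()  # assumed default-constructible; keep metadata_probability identical across runs
    raw = load_dataset("json", data_files="train.jsonl", split="train")  # assumed input file

    # Baseline ("without metadata") run: same examples, same chunking, but no
    # metadata prefix is prepended and the metadata mask stays all zeros.
    chunked = raw.map(
        partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=cfg, without_metadata=True),
        batched=True,
        remove_columns=raw.column_names,
    )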