
Bug about PromptEHR #19

@xansar

Description


When I tried to load PromptEHR from a pretrained checkpoint, the following error occurred:

AttributeError                            Traceback (most recent call last)
Input In [9], in <cell line: 5>()
      3 vocs = data['voc']
      4 model = PromptEHR()
----> 5 model.from_pretrained()

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/pytrial/tasks/trial_simulation/sequence/promptehr.py:222, in PromptEHR.from_pretrained(self, input_dir)
    211 def from_pretrained(self, input_dir='./simulation/pretrained_promptEHR'):
    212     '''
    213     Load pretrained PromptEHR model and make patient EHRs generation.
    214     Pretrained model was learned from MIMIC-III patient sequence data.
   (...)
    220         to this folder.
    221     '''
--> 222     self.model.from_pretrained(input_dir=input_dir)
    223     self.config.update(self.model.config)

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:359, in PromptEHR.from_pretrained(self, input_dir)
    356     print(f'Download pretrained PromptEHR model, save to {input_dir}.')
    358 print('Load pretrained PromptEHR model from', input_dir)
--> 359 self.load_model(input_dir)

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:298, in PromptEHR.load_model(self, checkpoint)
    295 self._load_tokenizer(data_tokenizer_file, model_tokenizer_file)
    297 # load configuration
--> 298 self.configuration = EHRBartConfig(self.data_tokenizer, self.model_tokenizer, n_num_feature=self.config['n_num_feature'], cat_cardinalities=self.config['cat_cardinalities'])
    299 self.configuration.from_pretrained(checkpoint)
    301 # build model

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/modeling_config.py:24, in EHRBartConfig(data_tokenizer, model_tokenizer, **kwargs)
     22 bart_config = BartConfig.from_pretrained('facebook/bart-base')
     23 kwargs.update(model_tokenizer.get_num_tokens)
---> 24 kwargs['data_tokenizer_num_vocab'] = len(data_tokenizer)
     25 if 'd_prompt_hidden' not in kwargs:
     26     kwargs['d_prompt_hidden'] = 128

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:431, in PreTrainedTokenizer.__len__(self)
    426 def __len__(self):
    427     """
    428     Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if
    429     there is a hole in the vocab, we will add tokenizers at a wrong index.
    430     """
--> 431     return len(set(self.get_vocab().keys()))

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/models/bart/tokenization_bart.py:243, in BartTokenizer.get_vocab(self)
    242 def get_vocab(self):
--> 243     return dict(self.encoder, **self.added_tokens_encoder)

File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:391, in PreTrainedTokenizer.added_tokens_encoder(self)
    385 @property
    386 def added_tokens_encoder(self) -> Dict[str, int]:
    387     """
    388     Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
    389     optimisation in `self._added_tokens_encoder` for the slow tokenizers.
    390     """
--> 391     return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}

AttributeError: 'DataTokenizer' object has no attribute '_added_tokens_decoder'
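
From the traceback, len(data_tokenizer) goes through the added_tokens_encoder property of the installed transformers, which reads the private attribute self._added_tokens_decoder. The DataTokenizer restored from the pretrained checkpoint does not have that attribute, so this looks like a version mismatch: promptehr's DataTokenizer seems to predate the added-token refactor in transformers (around 4.34). Below is a minimal workaround sketch that unblocks loading for me, assuming the attribute only needs to exist as an empty mapping and that DataTokenizer is importable from the promptehr package (the exact module path is my guess):

# Workaround sketch, not a proper fix: give DataTokenizer the backing
# attribute that newer transformers' added_tokens_encoder property reads.
from promptehr import DataTokenizer  # assumed import path; adjust as needed

if not hasattr(DataTokenizer, '_added_tokens_decoder'):
    # An empty mapping is enough for len(tokenizer) to work; the vocabulary
    # itself is restored from the checkpoint's vocab files.
    DataTokenizer._added_tokens_decoder = {}

Alternatively, pinning transformers to a version promptehr was developed against (before the added-token handling was refactored, i.e. transformers < 4.34) should avoid the error without any patching.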

——————————————————————————
My code is:

from pytrial.tasks.trial_simulation.data import SequencePatient
from pytrial.data.demo_data import load_synthetic_ehr_sequence
data = load_synthetic_ehr_sequence()

train_data = SequencePatient(
    data={
        'v': data['visit'],
        'y': data['y'],
        'x': data['feature'],
        },
    metadata={
        'visit': {'mode': 'dense'},
        'label': {'mode': 'tensor'},
        'voc': data['voc'],
        'max_visit': 20,
        'n_num_feature': data['n_num_feature'],
        'cat_cardinalities': data['cat_cardinalities'],
    }
)

from pytrial.tasks.trial_simulation.sequence import PromptEHR

vocs = data['voc']
model = PromptEHR()
model.from_pretrained()

By contrast, I can load BartTokenizer directly without any problem:

from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
tokenizer

BartTokenizer(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True),
}

tokenizer.added_tokens_decoder

{0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True)}
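
So the private attribute is present on a tokenizer constructed directly by transformers, which suggests the failure is specific to how promptehr restores its DataTokenizer, not to BartTokenizer itself. A quick check (values from my install; they may differ with other transformers versions):

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
# Newer transformers set this private mapping in __init__, so a tokenizer
# built here supports len() via added_tokens_encoder without issue.
print(hasattr(tokenizer, '_added_tokens_decoder'))  # True
print(len(tokenizer))  # 50265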

Could you please help me fix this bug?
