-
Notifications
You must be signed in to change notification settings - Fork 1
Added support for SalamandraTA 7B instructed #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,13 @@ | ||
| MOSES_TOKENIZER_DEFAULT_LANG = 'en' | ||
| HELSINKI_NLP = 'Helsinki-NLP' | ||
| MULTIMODALCODE = 'MULTI' | ||
| SUPPORTED_MODEL_TYPES = ['opus', 'opus-big', 'ctranslator2', 'dummy', 'custom', 'm2m100', 'nllb', 'salamandra'] | ||
| SUPPORTED_MODEL_TYPES = ['opus', 'opus-big', 'ctranslator2', 'dummy', 'custom', 'm2m100', 'nllb', 'salamandra','salamandra_instruct'] | ||
| MODEL_TAG_SEPARATOR = '-' | ||
|
|
||
| NLLB_CHECKPOINT_IDS = ["nllb-200-distilled-1.3B", "nllb-200-distilled-600M", "nllb-200-3.3B"] | ||
|
|
||
| M2M100_CHECKPOINT_IDS = ["m2m100_418M", "m2m100_1.2B"] | ||
|
|
||
| SALAMANDRA_CHECKPOINT_IDS = ["salamandraTA-2B"] | ||
| SALAMANDRA_CHECKPOINT_IDS = ["salamandraTA-2B"] | ||
|
|
||
| SALAMANDRA_INSTRUCT_CHECKPOINT_IDS = ["salamandraTA-7b-instruct"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -268,7 +268,7 @@ def translator(src_texts, src, tgt): | |
| #pipeline was here | ||
| def salamandra_translator(text, src, tgt, max_length=400): | ||
| prompt = f'[{src}] {text} \n[{tgt}]' | ||
| input_ids = tokenizer(prompt, return_tensors='pt').input_ids | ||
| input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device) | ||
| output_ids = model.generate( input_ids, max_length=500, num_beams=5 ) | ||
| input_length = input_ids.shape[1] | ||
|
|
||
|
|
@@ -289,10 +289,10 @@ def salamandra_translator(text, src, tgt, max_length=400): | |
| is_tokenizer_loaded = True | ||
|
|
||
| try: | ||
| model = AutoModelForCausalLM.from_pretrained(local_model) | ||
| model = AutoModelForCausalLM.from_pretrained(local_model, device_map="auto") | ||
| except Exception as e: | ||
| print(e) | ||
| model = AutoModelForCausalLM.from_pretrained(remote_model) | ||
| model = AutoModelForCausalLM.from_pretrained(remote_model, device_map="auto") | ||
| model.save_pretrained(local_model) | ||
| finally: | ||
| is_model_loaded = True | ||
|
|
@@ -301,4 +301,77 @@ def salamandra_translator(text, src, tgt, max_length=400): | |
| if is_tokenizer_loaded and is_model_loaded: | ||
| print("Loaded Salamandra model", remote_model) | ||
| return translator | ||
| return None | ||
|
|
||
|
|
||
| def get_batch_salamandra_instruct_translator(salamandra_inst_checkpoint_id:str, lang_map:dict=None) -> Optional[Callable[[str], str]]: | ||
|
|
||
| from datetime import datetime | ||
| from transformers import AutoTokenizer, AutoModelForCausalLM | ||
| import transformers | ||
| import torch | ||
|
|
||
| local_model = os.path.join(os.getenv('MODELS_ROOT'), salamandra_inst_checkpoint_id) | ||
| remote_model = salamandra_inst_checkpoint_id | ||
|
|
||
| is_model_loaded, is_tokenizer_loaded = False, False | ||
|
|
||
| def translator(src_texts, src, tgt): | ||
| print(lang_map) | ||
| if lang_map: | ||
| src = lang_map.get(src) if src in lang_map else src | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. question: do we want to use the src value provided if it is not in the language map?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe this part comes from the bilingual models. Without the source the API cannot choose the model to use. |
||
| tgt = lang_map.get(tgt) if tgt in lang_map else tgt | ||
|
|
||
| if not src_texts: | ||
| return '' | ||
| else: | ||
| #pipeline was here | ||
| def salamandra_inst_translator(text, src, tgt, max_length=400): | ||
|
|
||
| prompt = f"Translate the following text from {src} into {tgt}.\n{src}: {text} \n{tgt}:" | ||
| message = [ { "role": "user", "content": prompt } ] | ||
| date_string = datetime.today().strftime('%Y-%m-%d') | ||
|
|
||
| prompt = tokenizer.apply_chat_template( | ||
| message, | ||
| tokenize=False, | ||
| add_generation_prompt=True, | ||
| date_string=date_string | ||
| ) | ||
| inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") | ||
| input_length = inputs.shape[1] | ||
| outputs = model.generate(input_ids=inputs.to(model.device), | ||
| max_new_tokens=400, | ||
| early_stopping=True, | ||
| num_beams=5) | ||
|
|
||
| generated_text = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True) | ||
| return generated_text | ||
|
|
||
| return [salamandra_inst_translator(text, src, tgt, max_length=400) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. question: what will be the src value if there is no lang_map?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By default it would be "es" (Spanish), if I'm not wrong. |
||
| for text in src_texts] | ||
|
|
||
|
|
||
| try: | ||
| tokenizer = AutoTokenizer.from_pretrained(local_model) | ||
| except Exception as e: | ||
| print(e) | ||
| tokenizer = AutoTokenizer.from_pretrained(remote_model) | ||
| tokenizer.save_pretrained(local_model) | ||
| finally: | ||
| is_tokenizer_loaded = True | ||
|
|
||
| try: | ||
| model = AutoModelForCausalLM.from_pretrained(local_model, device_map="auto", torch_dtype=torch.bfloat16) | ||
| except Exception as e: | ||
| print(e) | ||
| model = AutoModelForCausalLM.from_pretrained(remote_model, device_map="auto", torch_dtype=torch.bfloat16) | ||
| model.save_pretrained(local_model) | ||
| finally: | ||
| is_model_loaded = True | ||
|
|
||
|
|
||
| if is_tokenizer_loaded and is_model_loaded: | ||
| print("Loaded Salamandra Instructed model", remote_model) | ||
| return translator | ||
| return None | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove print