Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions app/constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
MOSES_TOKENIZER_DEFAULT_LANG = 'en'
HELSINKI_NLP = 'Helsinki-NLP'
MULTIMODALCODE = 'MULTI'
# Model back-ends the loader accepts in config 'model_type'.
SUPPORTED_MODEL_TYPES = ['opus', 'opus-big', 'ctranslator2', 'dummy', 'custom', 'm2m100', 'nllb', 'salamandra', 'salamandra_instruct']
MODEL_TAG_SEPARATOR = '-'

# Known published checkpoint ids per model family (used to validate bare,
# un-prefixed checkpoint ids before an org prefix is prepended).
NLLB_CHECKPOINT_IDS = ["nllb-200-distilled-1.3B", "nllb-200-distilled-600M", "nllb-200-3.3B"]

M2M100_CHECKPOINT_IDS = ["m2m100_418M", "m2m100_1.2B"]

SALAMANDRA_CHECKPOINT_IDS = ["salamandraTA-2B"]

SALAMANDRA_INSTRUCT_CHECKPOINT_IDS = ["salamandraTA-7b-instruct"]
3 changes: 1 addition & 2 deletions app/helpers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _is_valid_model_type(self, model_type: str) -> bool:
def _load_models(self, load_all, models_to_load) -> None:
for model_config in self.config_data['models']:
_, _, model_id = self._get_ser_tgt_model_id(model_config)

if not load_all and model_id not in models_to_load:
continue

Expand Down Expand Up @@ -291,7 +291,6 @@ def _load_languages_list(self) -> None:
self.languages_list[source][target].append(model_id)
self.pair_to_model_id_map[model_id] = main_model_id

self._log_info(f'Languages list: {self.languages_list}')

def _lookup_pair_in_languages_list(self, src, tgt, alt=None):
if src in self.languages_list:
Expand Down
3 changes: 3 additions & 0 deletions app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@
#Specify which SALAMANDRA model to load here by default (if not specified in config as checkpoint_id)
DEFAULT_SALAMANDRA_MODEL_TYPE = "salamandraTA-2B"


#Specify which SALAMANDRA INSTRUCT model to load here by default (if not specified in config as checkpoint_id)
DEFAULT_SALAMANDRA_INSTRUCT_MODEL_TYPE = "salamandraTA-7b-instruct"
27 changes: 25 additions & 2 deletions app/utils/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_batch_nllbtranslator,
get_batch_m2m100translator,
get_batch_salamandratranslator,
get_batch_salamandra_instruct_translator,
dummy_translator,
get_custom_translator,
)
Expand All @@ -32,7 +33,7 @@
)

from app.settings import DEFAULT_NLLB_MODEL_TYPE, DEFAULT_M2M100_MODEL_TYPE, DEFAULT_SALAMANDRA_MODEL_TYPE, DEFAULT_SALAMANDRA_INSTRUCT_MODEL_TYPE
from app.constants import NLLB_CHECKPOINT_IDS, M2M100_CHECKPOINT_IDS, SALAMANDRA_CHECKPOINT_IDS, SALAMANDRA_INSTRUCT_CHECKPOINT_IDS

def load_model_sentence_segmenter(
model: Dict,
Expand Down Expand Up @@ -275,7 +276,29 @@ def load_model_translator(
f'Failed to load salamandra-huggingface model for {model_id}. Skipping load.'
)
raise ModelLoadingException


elif model_config['model_type'] == 'salamandra_instruct':
salamandra_checkpoint_id = model_config.get('checkpoint_id') if 'checkpoint_id' in model_config else DEFAULT_SALAMANDRA_MODEL_TYPE
if len(model_config.get('checkpoint_id').split('/')) == 1:
if salamandra_checkpoint_id not in SALAMANDRA_INSTRUCT_CHECKPOINT_IDS:
warn(
f'No checkpoint exists for base salamandra model: BSC-LT/{salamandra_checkpoint_id}. Skipping load.'
)
raise ModelLoadingException
salamandra_checkpoint_id = 'BSC-LT/' + salamandra_checkpoint_id
warn(f'Full model id: {salamandra_checkpoint_id}')

translator = get_batch_salamandra_instruct_translator(salamandra_checkpoint_id, lang_map=model_config.get('lang_code_map'))
if translator:
model['translator'] = translator
msg += '-salamandra-huggingface-' + salamandra_checkpoint_id
else:
warn(
f'Failed to load salamandra-huggingface model for {model_id}. Skipping load.'
)
raise ModelLoadingException


elif model_config['model_type'] == 'dummy':
msg += '-dummy'
model['translator'] = dummy_translator
Expand Down
79 changes: 76 additions & 3 deletions app/utils/translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def translator(src_texts, src, tgt):
#pipeline was here
def salamandra_translator(text, src, tgt, max_length=400):
prompt = f'[{src}] {text} \n[{tgt}]'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
output_ids = model.generate( input_ids, max_length=500, num_beams=5 )
input_length = input_ids.shape[1]

Expand All @@ -289,10 +289,10 @@ def salamandra_translator(text, src, tgt, max_length=400):
is_tokenizer_loaded = True

try:
model = AutoModelForCausalLM.from_pretrained(local_model)
model = AutoModelForCausalLM.from_pretrained(local_model, device_map="auto")
except Exception as e:
print(e)
model = AutoModelForCausalLM.from_pretrained(remote_model)
model = AutoModelForCausalLM.from_pretrained(remote_model, device_map="auto")
model.save_pretrained(local_model)
finally:
is_model_loaded = True
Expand All @@ -301,4 +301,77 @@ def salamandra_translator(text, src, tgt, max_length=400):
if is_tokenizer_loaded and is_model_loaded:
print("Loaded Salamandra model", remote_model)
return translator
return None


def get_batch_salamandra_instruct_translator(salamandra_inst_checkpoint_id: str, lang_map: dict = None) -> Optional[Callable[[str], str]]:
    """Load a SalamandraTA instruct checkpoint and return a batch translator.

    Tries the local copy under MODELS_ROOT first; on failure downloads the
    remote checkpoint and caches it locally. Returns None if loading fails.

    :param salamandra_inst_checkpoint_id: checkpoint id (e.g. 'BSC-LT/salamandraTA-7b-instruct')
    :param lang_map: optional mapping from API language codes to the codes
                     the model expects; unmapped codes are passed through as-is
    """
    from datetime import datetime
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch

    local_model = os.path.join(os.getenv('MODELS_ROOT'), salamandra_inst_checkpoint_id)
    remote_model = salamandra_inst_checkpoint_id

    is_model_loaded, is_tokenizer_loaded = False, False

    def translator(src_texts, src, tgt):
        # NOTE(review): debug print(lang_map) removed per review comment.
        if lang_map:
            # Unmapped codes fall through unchanged (bilingual models rely
            # on the caller-provided code to select the model).
            src = lang_map.get(src, src)
            tgt = lang_map.get(tgt, tgt)

        if not src_texts:
            return ''

        def salamandra_inst_translator(text, src, tgt, max_length=400):
            # Build the instruct-style prompt and render it through the
            # model's chat template (date_string is required by the template).
            prompt = f"Translate the following text from {src} into {tgt}.\n{src}: {text} \n{tgt}:"
            message = [{"role": "user", "content": prompt}]
            date_string = datetime.today().strftime('%Y-%m-%d')

            chat_prompt = tokenizer.apply_chat_template(
                message,
                tokenize=False,
                add_generation_prompt=True,
                date_string=date_string
            )
            inputs = tokenizer.encode(chat_prompt, add_special_tokens=False, return_tensors="pt")
            input_length = inputs.shape[1]
            # Bug fix: honor the max_length parameter (was hard-coded to 400,
            # leaving the parameter dead).
            outputs = model.generate(input_ids=inputs.to(model.device),
                                     max_new_tokens=max_length,
                                     early_stopping=True,
                                     num_beams=5)

            # Decode only the newly generated tokens, dropping the prompt.
            return tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)

        return [salamandra_inst_translator(text, src, tgt, max_length=400)
                for text in src_texts]

    try:
        tokenizer = AutoTokenizer.from_pretrained(local_model)
    except Exception as e:
        print(e)
        # Local copy missing/corrupt: fetch remote and cache it for next time.
        tokenizer = AutoTokenizer.from_pretrained(remote_model)
        tokenizer.save_pretrained(local_model)
    finally:
        is_tokenizer_loaded = True

    try:
        model = AutoModelForCausalLM.from_pretrained(local_model, device_map="auto", torch_dtype=torch.bfloat16)
    except Exception as e:
        print(e)
        model = AutoModelForCausalLM.from_pretrained(remote_model, device_map="auto", torch_dtype=torch.bfloat16)
        model.save_pretrained(local_model)
    finally:
        is_model_loaded = True

    if is_tokenizer_loaded and is_model_loaded:
        print("Loaded Salamandra Instructed model", remote_model)
        return translator
    return None
Loading