Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ jobs:
- name: Run tests
run: |
export HUGGINGFACE_HUB_CACHE=/app/models
export HF_HUB_ENABLE_HF_TRANSFER=1
export PORT=80
pytest


Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ stop:


act-run-tests:
gh act -j run-tests -W '.github/workflows/tests.yml'
gh act -j test -W '.github/workflows/tests.yml'
4 changes: 2 additions & 2 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from contextlib import asynccontextmanager

def create_app(load_all_models=False, models_to_load=[]) -> FastAPI:
def create_app(models_to_load=[]) -> FastAPI:
@asynccontextmanager
async def lifespan(application: FastAPI):
config = Config(load_all_models=load_all_models, models_to_load=models_to_load)
config = Config(models_to_load=models_to_load)
yield

app = FastAPI(lifespan=lifespan)
Expand Down
8 changes: 3 additions & 5 deletions app/helpers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(
self,
config_file: Optional[str] = None,
config_data: Optional[Dict] = None,
load_all_models: bool = False,
models_to_load:list = []
):
self.loaded_models: Dict = {}
Expand All @@ -33,7 +32,6 @@ def __init__(
self.pair_to_model_id_map: Dict = {}
self.config_data: Dict = config_data or {}
self.config_file: str = config_file or CONFIG_JSON_PATH
self.load_all_models: bool = load_all_models

self.warnings: List[str] = []
self.messages: List[str] = []
Expand All @@ -43,7 +41,7 @@ def __init__(
self._validate_models()

self._load_language_codes()
self._load_models(load_all_models, models_to_load)
self._load_models(models_to_load)
self._load_languages_list()

def map_lang_to_closest(self, lang: str) -> str:
Expand Down Expand Up @@ -145,11 +143,11 @@ def _is_valid_model_type(self, model_type: str) -> bool:
return False
return True

def _load_models(self, load_all, models_to_load) -> None:
def _load_models(self, models_to_load) -> None:
for model_config in self.config_data['models']:
_, _, model_id = self._get_ser_tgt_model_id(model_config)

if not load_all and model_id not in models_to_load:
if 'all' not in models_to_load and model_id not in models_to_load:
continue

# CONFIG CHECKS
Expand Down
53 changes: 28 additions & 25 deletions app/tests/api/v1/test_api_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,41 @@ def setup_before_each_test(self):
def test_list_languages(self):
with TestClient(self.app) as client:
response = client.get(url=self.get_endpoint('/'))
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content['models'] == {'ca': {'es': ['ca-es']}, 'es': {'ca': ['es-ca']}}
assert content['languages'] == {
"es": "Spanish",
"ca": "Catalan",
"en": "English",
"fr": "French",
"de": "German",
"it": "Italian",
"pt": "Portuguese"
}
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content['models'] == {'ca': {'es': ['ca-es']}, 'es': {'ca': ['es-ca']}}
assert content['languages'] == {
"es": "Spanish",
"ca": "Catalan",
"en": "English",
"fr": "French",
"de": "German",
"it": "Italian",
"pt": "Portuguese"
}

def test_translate_text_valid_code(self):
options = {
'src': 'es',
'tgt': 'ca',
'text': '¿Cómo estás?',
}
expected_translation = 'Com estàs?'
with TestClient(self.app) as client:
options = {
'src': 'es',
'tgt': 'ca',
'text': '¿Cómo estás?',
}
expected_translation = 'Com estàs?'
response = client.post(self.get_endpoint('/'), content=json.dumps(options))
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content['translation'] == expected_translation
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content['translation'] == expected_translation

def test_translate_text_invalid_code(self):
options = {
'src': 'sfs',
'tgt': 'xyz',
'text': 'Hello there, how are you doing?',
}
response = self.client.post(self.get_endpoint('/'), content=json.dumps(options))
assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE
with TestClient(self.app) as client:
response = client.post(self.get_endpoint('/'), content=json.dumps(options))
assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE

def test_batch_translate_text_valid_code(self):
options = {
Expand All @@ -57,7 +58,8 @@ def test_batch_translate_text_valid_code(self):
'texts': ['Hola, ¿Cómo te llamas?', '¿Cómo estás?'],
}
expected_translations = ['Hola, com et dius?', 'Com estàs?']
response = self.client.post(url=self.get_endpoint('/batch'), content=json.dumps(options))
with TestClient(self.app) as client:
response = client.post(url=self.get_endpoint('/batch'), content=json.dumps(options))
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content['translation'] == expected_translations
Expand All @@ -68,5 +70,6 @@ def test_batch_translate_text_invalid_code(self):
'tgt': 'xyz',
'texts': ['Hola, ¿Cómo te llamas?', '¿Cómo estás?'],
}
response = self.client.post(url=self.get_endpoint('/batch'), content=json.dumps(options))
with TestClient(self.app) as client:
response = client.post(url=self.get_endpoint('/batch'), content=json.dumps(options))
assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE
5 changes: 2 additions & 3 deletions app/tests/base_test_case.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from fastapi.testclient import TestClient
from app import create_app
from main import app
import os

class BaseTestCase:

def setup(self):
os.environ['MODELS_ROOT'] = "./models"
self.app = create_app(models_to_load=["es-ca", "ca-es"])
self.app = app
self.client = TestClient(self.app)


Expand Down
2 changes: 1 addition & 1 deletion app/tests/test_load_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_load_models_with_warnings(self):
}]
}

self.config = Config(config_data=self.config_data, load_all_models=True)
self.config = Config(config_data=self.config_data, models_to_load=["all"])
# languages
assert self.config.language_codes == {
"es": "Spanish",
Expand Down
3 changes: 3 additions & 0 deletions app/tests/test_translations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""
import json
import pytest
from app.utils.translate import translate_text
Expand All @@ -23,9 +24,11 @@ def test_translate_text_es_ca(self):
assert translation == expected_translation



def test_translate_text_ca_es(self):
model_id = get_model_id('ca', 'es')
text = 'Hola, com estàs?'
expected_translation = 'Hola, ¿cómo estás?'
translation = translate_text(model_id, text, 'ca', 'es')
assert translation == expected_translation
"""
1 change: 0 additions & 1 deletion app/views/v1/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

def fetch_model_data_from_request(request):
config = Config()

src = config.map_lang_to_closest(request.src)
tgt = config.map_lang_to_closest(request.tgt)
use_multi = True if request.use_multi == 'True' else False
Expand Down
26 changes: 16 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@
import uvicorn
from app import create_app

if __name__ == "__main__":

def get_arguments():
parser = argparse.ArgumentParser(description="An API designed to provide translation services for text between different languages.")
parser.add_argument("-m", "--models", type=str, default="./models", help="Directory path of models", required=False)
parser.add_argument("-l", "--load", type=str, nargs="+", help="Option to load models, if it contains 'all' it will download all models", default=["es-ca", "ca-es"])
parser.add_argument("-r", "--reload", type=bool, help="Reload api on changes", action=argparse.BooleanOptionalAction)
parser.add_argument("-w", "--workers", type=int, help="Number of workers to run the api", default=None)
parser.add_argument("--host", type=str, help="Host to run the app", default="0.0.0.0")
parser.add_argument("--port", type=int, help="Port to run the app", default=8000)
parser.add_argument("--logs", type=str, help="Logging config file", default="logging.yml")

return parser

args = parser.parse_args()
os.environ['MODELS_ROOT'] = args.models
models_to_load = args.load
if 'all' in models_to_load:
load_all_models = True
else:
load_all_models = False

app = create_app(load_all_models, models_to_load)
uvicorn.run(app, host="0.0.0.0", port=8000, log_config = "logging.yml")
args = get_arguments().parse_args()
os.environ['MODELS_ROOT'] = args.models

app = create_app(args.load)

if __name__ == "__main__":
uvicorn.run("main:app", host=args.host, port=args.port, log_config=args.logs, reload=args.reload, workers=args.workers)
4 changes: 1 addition & 3 deletions run_local.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
export MT_API_CONFIG=config.json
export MODELS_ROOT=./models
python main.py --models './models' --load all
python main.py
3 changes: 0 additions & 3 deletions run_test.sh

This file was deleted.