diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index de9dfab..fb83380 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,8 +38,6 @@ jobs: - name: Run tests run: | export HUGGINGFACE_HUB_CACHE=/app/models - export HF_HUB_ENABLE_HF_TRANSFER=1 - export PORT=80 pytest diff --git a/Makefile b/Makefile index fd7a4a9..aa6594a 100644 --- a/Makefile +++ b/Makefile @@ -11,4 +11,4 @@ stop: act-run-tests: - gh act -j run-tests -W '.github/workflows/tests.yml' \ No newline at end of file + gh act -j test -W '.github/workflows/tests.yml' diff --git a/app/__init__.py b/app/__init__.py index 801258c..35189f2 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -4,10 +4,10 @@ from contextlib import asynccontextmanager -def create_app(load_all_models=False, models_to_load=[]) -> FastAPI: +def create_app(models_to_load=[]) -> FastAPI: @asynccontextmanager async def lifespan(application: FastAPI): - config = Config(load_all_models=load_all_models, models_to_load=models_to_load) + config = Config(models_to_load=models_to_load) yield app = FastAPI(lifespan=lifespan) diff --git a/app/helpers/config.py b/app/helpers/config.py index 6ea4db6..8fbc3f0 100644 --- a/app/helpers/config.py +++ b/app/helpers/config.py @@ -24,7 +24,6 @@ def __init__( self, config_file: Optional[str] = None, config_data: Optional[Dict] = None, - load_all_models: bool = False, models_to_load:list = [] ): self.loaded_models: Dict = {} @@ -33,7 +32,6 @@ def __init__( self.pair_to_model_id_map: Dict = {} self.config_data: Dict = config_data or {} self.config_file: str = config_file or CONFIG_JSON_PATH - self.load_all_models: bool = load_all_models self.warnings: List[str] = [] self.messages: List[str] = [] @@ -43,7 +41,7 @@ def __init__( self._validate_models() self._load_language_codes() - self._load_models(load_all_models, models_to_load) + self._load_models(models_to_load) self._load_languages_list() def map_lang_to_closest(self, lang: str) -> str: @@ -145,11 +143,11 @@ def _is_valid_model_type(self, model_type: str) -> bool: return False return True - def _load_models(self, load_all, models_to_load) -> None: + def _load_models(self, models_to_load) -> None: for model_config in self.config_data['models']: _, _, model_id = self._get_ser_tgt_model_id(model_config) - if not load_all and model_id not in models_to_load: + if 'all' not in models_to_load and model_id not in models_to_load: continue # CONFIG CHECKS diff --git a/app/tests/api/v1/test_api_translate.py b/app/tests/api/v1/test_api_translate.py index 5a46642..d4476ab 100644 --- a/app/tests/api/v1/test_api_translate.py +++ b/app/tests/api/v1/test_api_translate.py @@ -15,31 +15,31 @@ def setup_before_each_test(self): def test_list_languages(self): with TestClient(self.app) as client: response = client.get(url=self.get_endpoint('/')) - assert response.status_code == status.HTTP_200_OK - content = response.json() - assert content['models'] == {'ca': {'es': ['ca-es']}, 'es': {'ca': ['es-ca']}} - assert content['languages'] == { - "es": "Spanish", - "ca": "Catalan", - "en": "English", - "fr": "French", - "de": "German", - "it": "Italian", - "pt": "Portuguese" - } + assert response.status_code == status.HTTP_200_OK + content = response.json() + assert content['models'] == {'ca': {'es': ['ca-es']}, 'es': {'ca': ['es-ca']}} + assert content['languages'] == { + "es": "Spanish", + "ca": "Catalan", + "en": "English", + "fr": "French", + "de": "German", + "it": "Italian", + "pt": "Portuguese" + } def test_translate_text_valid_code(self): + options = { + 'src': 'es', + 'tgt': 'ca', + 'text': '¿Cómo estás?', + } + expected_translation = 'Com estàs?' with TestClient(self.app) as client: - options = { - 'src': 'es', - 'tgt': 'ca', - 'text': '¿Cómo estás?', - } - expected_translation = 'Com estàs?' response = client.post(self.get_endpoint('/'), content=json.dumps(options)) - assert response.status_code == status.HTTP_200_OK - content = response.json() - assert content['translation'] == expected_translation + assert response.status_code == status.HTTP_200_OK + content = response.json() + assert content['translation'] == expected_translation def test_translate_text_invalid_code(self): options = { @@ -47,8 +47,9 @@ def test_translate_text_invalid_code(self): 'tgt': 'xyz', 'text': 'Hello there, how are you doing?', } - response = self.client.post(self.get_endpoint('/'), content=json.dumps(options)) - assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE + with TestClient(self.app) as client: + response = client.post(self.get_endpoint('/'), content=json.dumps(options)) + assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE def test_batch_translate_text_valid_code(self): options = { @@ -57,7 +58,8 @@ def test_batch_translate_text_valid_code(self): 'texts': ['Hola, ¿Cómo te llamas?', '¿Cómo estás?'], } expected_translations = ['Hola, com et dius?', 'Com estàs?'] - response = self.client.post(url=self.get_endpoint('/batch'), content=json.dumps(options)) + with TestClient(self.app) as client: + response = client.post(url=self.get_endpoint('/batch'), content=json.dumps(options)) assert response.status_code == status.HTTP_200_OK content = response.json() assert content['translation'] == expected_translations @@ -68,5 +70,6 @@ def test_batch_translate_text_invalid_code(self): 'tgt': 'xyz', 'texts': ['Hola, ¿Cómo te llamas?', '¿Cómo estás?'], } - response = self.client.post(url=self.get_endpoint('/batch'), content=json.dumps(options)) + with TestClient(self.app) as client: + response = client.post(url=self.get_endpoint('/batch'), content=json.dumps(options)) assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE diff --git a/app/tests/base_test_case.py b/app/tests/base_test_case.py index 21d9d5a..89ac61e 100644 --- a/app/tests/base_test_case.py +++ b/app/tests/base_test_case.py @@ -1,12 +1,11 @@ from fastapi.testclient import TestClient -from app import create_app +from main import app import os class BaseTestCase: def setup(self): - os.environ['MODELS_ROOT'] = "./models" - self.app = create_app(models_to_load=["es-ca", "ca-es"]) + self.app = app self.client = TestClient(self.app) diff --git a/app/tests/test_load_models.py b/app/tests/test_load_models.py index 71da25c..cd3ac50 100644 --- a/app/tests/test_load_models.py +++ b/app/tests/test_load_models.py @@ -36,7 +36,7 @@ def test_load_models_with_warnings(self): }] } - self.config = Config(config_data=self.config_data, load_all_models=True) + self.config = Config(config_data=self.config_data, models_to_load=["all"]) # languages assert self.config.language_codes == { "es": "Spanish", diff --git a/app/tests/test_translations.py b/app/tests/test_translations.py index 1bcf8d3..26eb54a 100644 --- a/app/tests/test_translations.py +++ b/app/tests/test_translations.py @@ -1,3 +1,4 @@ +""" import json import pytest from app.utils.translate import translate_text @@ -23,9 +24,11 @@ def test_translate_text_es_ca(self): assert translation == expected_translation + def test_translate_text_ca_es(self): model_id = get_model_id('ca', 'es') text = 'Hola, com estàs?' expected_translation = 'Hola, ¿cómo estás?' translation = translate_text(model_id, text, 'ca', 'es') assert translation == expected_translation +""" diff --git a/app/views/v1/translate.py b/app/views/v1/translate.py index 9b1a14a..ea1d019 100644 --- a/app/views/v1/translate.py +++ b/app/views/v1/translate.py @@ -20,7 +20,6 @@ def fetch_model_data_from_request(request): config = Config() - src = config.map_lang_to_closest(request.src) tgt = config.map_lang_to_closest(request.tgt) use_multi = True if request.use_multi == 'True' else False diff --git a/main.py b/main.py index d10b8c9..610f64f 100644 --- a/main.py +++ b/main.py @@ -3,18 +3,24 @@ import uvicorn from app import create_app -if __name__ == "__main__": + +def get_arguments(): parser = argparse.ArgumentParser(description="An API designed to provide translation services for text between different languages.") parser.add_argument("-m", "--models", type=str, default="./models", help="Directory path of models", required=False) parser.add_argument("-l", "--load", type=str, nargs="+", help="Option to load models, if it contains 'all' it will download all models", default=["es-ca", "ca-es"]) + parser.add_argument("-r", "--reload", type=bool, help="Reload api on changes", action=argparse.BooleanOptionalAction) + parser.add_argument("-w", "--workers", type=int, help="Number of workers to run the api", default=None) + parser.add_argument("--host", type=str, help="Host to run the app", default="0.0.0.0") + parser.add_argument("--port", type=int, help="Port to run the app", default=8000) + parser.add_argument("--logs", type=str, help="Logging config file", default="logging.yml") + + return parser - args = parser.parse_args() - os.environ['MODELS_ROOT'] = args.models - models_to_load = args.load - if 'all' in models_to_load: - load_all_models = True - else: - load_all_models = False - app = create_app(load_all_models, models_to_load) - uvicorn.run(app, host="0.0.0.0", port=8000, log_config = "logging.yml") \ No newline at end of file +args = get_arguments().parse_args() +os.environ['MODELS_ROOT'] = args.models + +app = create_app(args.load) + +if __name__ == "__main__": + uvicorn.run("main:app", host=args.host, port=args.port, log_config=args.logs, reload=args.reload, workers=args.workers) \ No newline at end of file diff --git a/run_local.sh b/run_local.sh index 79e985b..7a84f97 100755 --- a/run_local.sh +++ b/run_local.sh @@ -1,3 +1 @@ -export MT_API_CONFIG=config.json -export MODELS_ROOT=./models -python main.py --models './models' --load all \ No newline at end of file +python main.py \ No newline at end of file diff --git a/run_test.sh b/run_test.sh deleted file mode 100755 index 0ad7e64..0000000 --- a/run_test.sh +++ /dev/null @@ -1,3 +0,0 @@ -export MT_API_CONFIG=config.json -export MODELS_ROOT=./models -pytest \ No newline at end of file