diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py index 1d7b7c7f..1dcee6b2 100644 --- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py +++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py @@ -8,10 +8,21 @@ from infinity_emb.primitives import Device if CHECK_OPTIMUM.is_available: - from optimum.bettertransformer import ( # type: ignore[import-untyped] - BetterTransformer, - BetterTransformerManager, - ) + try: + from optimum.bettertransformer import ( # type: ignore[import-untyped] + BetterTransformer, + BetterTransformerManager, + ) + BETTERTRANSFORMER_AVAILABLE = True + except (ImportError, RuntimeError): + # BetterTransformer is deprecated in newer versions of optimum + BETTERTRANSFORMER_AVAILABLE = False + BetterTransformer = None + BetterTransformerManager = None +else: + BETTERTRANSFORMER_AVAILABLE = False + BetterTransformer = None + BetterTransformerManager = None if CHECK_TORCH.is_available: import torch @@ -37,6 +48,9 @@ def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool: if not engine_args.bettertransformer: return False + if not BETTERTRANSFORMER_AVAILABLE: + return False + config = AutoConfig.from_pretrained( pretrained_model_name_or_path=engine_args.model_name_or_path, revision=engine_args.revision, @@ -50,6 +64,10 @@ def to_bettertransformer(model: "PreTrainedModel", engine_args: "EngineArgs", lo if not engine_args.bettertransformer: return model + if not BETTERTRANSFORMER_AVAILABLE: + logger.info("BetterTransformer is not available due to version incompatibility. Continuing without optimization.") + return model + if engine_args.device == Device.mps or ( hasattr(model, "device") and model.device.type == "mps" ): diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml index 48a99694..20cf3c2d 100644 --- a/libs/infinity_emb/pyproject.toml +++ b/libs/infinity_emb/pyproject.toml @@ -80,7 +80,7 @@ jinja2-cli = "*" torch = "2.8.0" prometheus-fastapi-instrumentator = "7.0.0" # sentence-transformers = "3.3.1" -transformers = "4.47.0" +transformers = "4.53.3" fastapi = "0.115.2" [tool.poetry.group.codespell.dependencies]