From 8fe76b968ba583033a9b94a3e4a15e36d7090288 Mon Sep 17 00:00:00 2001
From: sunyi001 <1659275352@qq.com>
Date: Fri, 29 Nov 2024 10:51:08 +0800
Subject: [PATCH] Specify ASCEND NPU for inference.

---
 fastchat/serve/cli.py                | 9 +++++++++
 fastchat/serve/model_worker.py       | 9 +++++++++
 fastchat/serve/multi_model_worker.py | 9 +++++++++
 3 files changed, 27 insertions(+)

diff --git a/fastchat/serve/cli.py b/fastchat/serve/cli.py
index 78f7f51b1..34d511131 100644
--- a/fastchat/serve/cli.py
+++ b/fastchat/serve/cli.py
@@ -13,6 +13,7 @@
 - Type "!!save <filename>" to save the conversation history to a json file.
 - Type "!!load <filename>" to load a conversation history from a json file.
 """
+
 import argparse
 import os
 import re
@@ -197,6 +198,14 @@ def main(args):
         )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
         os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
+        if len(args.gpus.split(",")) == 1:
+            try:
+                import torch_npu
+
+                torch.npu.set_device(int(args.gpus))
+                print(f"NPU is available, now model is running on npu:{args.gpus}")
+            except ModuleNotFoundError:
+                pass
     if args.enable_exllama:
         exllama_config = ExllamaConfig(
             max_seq_len=args.exllama_max_seq_len,
diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py
index 683a78556..2043fb5e9 100644
--- a/fastchat/serve/model_worker.py
+++ b/fastchat/serve/model_worker.py
@@ -1,6 +1,7 @@
 """
 A model worker that executes the model.
 """
+
 import argparse
 import base64
 import gc
@@ -351,6 +352,14 @@ def create_model_worker():
             f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
         )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+        if len(args.gpus.split(",")) == 1:
+            try:
+                import torch_npu
+
+                torch.npu.set_device(int(args.gpus))
+                print(f"NPU is available, now model is running on npu:{args.gpus}")
+            except ModuleNotFoundError:
+                pass

     gptq_config = GptqConfig(
         ckpt=args.gptq_ckpt or args.model_path,
diff --git a/fastchat/serve/multi_model_worker.py b/fastchat/serve/multi_model_worker.py
index 5e6266fe0..dfbb4dbaf 100644
--- a/fastchat/serve/multi_model_worker.py
+++ b/fastchat/serve/multi_model_worker.py
@@ -11,6 +11,7 @@
 We recommend using this with multiple Peft models (with `peft` in the name)
 where all Peft models are trained on the exact same base model.
 """
+
 import argparse
 import asyncio
 import dataclasses
@@ -206,6 +207,14 @@ def create_multi_model_worker():
             f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
        )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+        if len(args.gpus.split(",")) == 1:
+            try:
+                import torch_npu
+
+                torch.npu.set_device(int(args.gpus))
+                print(f"NPU is available, now model is running on npu:{args.gpus}")
+            except ModuleNotFoundError:
+                pass

     gptq_config = GptqConfig(
         ckpt=args.gptq_ckpt or args.model_path,
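
Note: the same eight-line block is inserted into all three entry points, and it
only binds a device when exactly one index is passed via --gpus; with several
indices, placement is left to the later model-loading code. Read in isolation,
the pattern looks like the sketch below. This is a minimal illustration under
stated assumptions, not FastChat code: the helper name `maybe_bind_single_npu`
is ours, and it assumes a Huawei Ascend machine where the optional `torch_npu`
package may or may not be installed.

    import os

    import torch


    def maybe_bind_single_npu(gpus: str) -> None:
        """Pin inference to one Ascend NPU when a single device index is given.

        `gpus` mirrors the raw --gpus CLI argument, e.g. "0" or "0,1".
        Importing torch_npu registers the `torch.npu` namespace as a side
        effect; if the package is absent, we silently fall back to the
        existing CUDA/CPU path.
        """
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus
        if len(gpus.split(",")) == 1:
            try:
                import torch_npu  # noqa: F401 -- import needed for torch.npu

                torch.npu.set_device(int(gpus))
                print(f"NPU is available, now model is running on npu:{gpus}")
            except ModuleNotFoundError:
                pass


    # Hypothetical usage: bind this process to Ascend NPU 0 if torch_npu exists.
    maybe_bind_single_npu("0")

Catching only ModuleNotFoundError keeps the change inert on CUDA-only hosts:
the try block is the sole NPU-specific code path, so no new dependency is
introduced for existing deployments.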