diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 0af680bb5..828b33420 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -24,6 +24,9 @@
 )
 from fastchat.utils import get_context_length, is_partial_stop
 
+import os
+
+os.environ["VLLM_USE_V1"] = os.environ.get("VLLM_USE_V1", "0")
 
 app = FastAPI()
 
@@ -71,7 +74,7 @@ async def generate_stream(self, params):
         request_id = params.pop("request_id")
         temperature = float(params.get("temperature", 1.0))
         top_p = float(params.get("top_p", 1.0))
-        top_k = params.get("top_k", -1.0)
+        top_k = params.get("top_k", -1)
         presence_penalty = float(params.get("presence_penalty", 0.0))
         frequency_penalty = float(params.get("frequency_penalty", 0.0))
         max_new_tokens = params.get("max_new_tokens", 256)
@@ -107,7 +110,7 @@ async def generate_stream(self, params):
             n=1,
             temperature=temperature,
             top_p=top_p,
-            use_beam_search=use_beam_search,
+            # use_beam_search=use_beam_search,
             stop=list(stop),
             stop_token_ids=stop_token_ids,
             max_tokens=max_new_tokens,
@@ -156,9 +159,11 @@ async def generate_stream(self, params):
                 "cumulative_logprob": [
                     output.cumulative_logprob for output in request_output.outputs
                 ],
-                "finish_reason": request_output.outputs[0].finish_reason
-                if len(request_output.outputs) == 1
-                else [output.finish_reason for output in request_output.outputs],
+                "finish_reason": (
+                    request_output.outputs[0].finish_reason
+                    if len(request_output.outputs) == 1
+                    else [output.finish_reason for output in request_output.outputs]
+                ),
             }
             # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
             # This aligns with the behavior of model_worker.
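
Note: the hunk above simply comments out the use_beam_search keyword, which recent vLLM releases no longer accept on SamplingParams. A minimal sketch of a version-tolerant alternative is shown below; build_sampling_params is a hypothetical helper and is not part of this patch, it only illustrates falling back when the keyword is rejected.

# Hypothetical sketch (not in the patch): keep the worker compatible with
# vLLM versions both with and without the use_beam_search keyword.
from vllm import SamplingParams


def build_sampling_params(use_beam_search: bool, **kwargs) -> SamplingParams:
    if use_beam_search:
        try:
            # Older vLLM: SamplingParams still accepts use_beam_search.
            return SamplingParams(use_beam_search=True, **kwargs)
        except TypeError:
            # Newer vLLM removed the argument; fall back to plain sampling.
            pass
    return SamplingParams(**kwargs)

Catching TypeError avoids pinning against a specific vLLM version string: older installations keep beam search, newer ones fall back to regular sampling, which matches the behavior of the commented-out line.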