From 67feb875f5d1a19ad17dd1a0d453453fbdff0722 Mon Sep 17 00:00:00 2001
From: Jimin Park
Date: Mon, 8 Dec 2025 07:01:49 +0000
Subject: [PATCH] benchmarks: support a fixed ShareGPT input length

Add a --sharegpt-input-len option to the benchmark dataset sampler. When it
is set, consecutive ShareGPT prompts are greedily packed, with the final
piece truncated at the token level, so that every sampled request has a
fixed input length, mirroring what --sharegpt-output-len already does for
outputs.

---
 vllm/benchmarks/datasets.py | 98 ++++++++++++++++++++++++++++---------
 1 file changed, 75 insertions(+), 23 deletions(-)

diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index ec9b0fd6e969..d4c6e5ee1c6a 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -1235,6 +1235,7 @@ def sample(
         num_requests: int,
         lora_path: str | None = None,
         max_loras: int | None = None,
+        input_len: int | None = None,
         output_len: int | None = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
@@ -1242,6 +1243,9 @@
         **kwargs,
     ) -> list:
         samples: list = []
+        new_prompt_cnt = 0
+        new_prompt = ""
+        new_prompt_len = 0
         ind = 0
         for entry in self.data:
             if len(samples) >= num_requests:
@@ -1258,30 +1262,70 @@
             completion_ids = tokenizer(completion).input_ids
             prompt_len = len(prompt_ids)
             new_output_len = len(completion_ids) if output_len is None else output_len
-            if not is_valid_sequence(
-                prompt_len,
-                new_output_len,
-                skip_min_output_len_check=output_len is not None,
-            ):
-                continue
-            if image_path := entry.get("image"):
-                mm_content = process_image(image_path)
-            elif video_path := entry.get("video"):
-                mm_content = process_video(video_path)
+            if input_len:
+                # Pack consecutive prompts until the fixed input_len is reached.
+                remaining_len = input_len - new_prompt_len
+                if prompt_len > input_len:
+                    continue
+                if remaining_len > prompt_len:
+                    new_prompt = f"{new_prompt} {prompt}" if new_prompt else prompt
+                    new_prompt_ids = tokenizer(new_prompt).input_ids
+                    new_prompt_len = len(new_prompt_ids)
+                    continue
+                else:
+                    new_prompt = new_prompt + tokenizer.decode(prompt_ids[:remaining_len])
+                    new_prompt_ids = tokenizer(new_prompt).input_ids
+                    new_prompt_len = len(new_prompt_ids)
+                if not is_valid_sequence(
+                    new_prompt_len,
+                    new_output_len,
+                    min_len=input_len,
+                    max_prompt_len=input_len,
+                    max_total_len=input_len + new_output_len,
+                    skip_min_output_len_check=output_len is not None,
+                ):
+                    continue
+                if enable_multimodal_chat:
+                    new_prompt = self.apply_multimodal_chat_transformation(new_prompt, None)
+                samples.append(
+                    SampleRequest(
+                        prompt=new_prompt,
+                        prompt_len=new_prompt_len,
+                        expected_output_len=new_output_len,
+                        lora_request=lora_request,
+                        multi_modal_data=None,
+                        request_id=request_id_prefix + str(ind),
+                    )
+                )
+                new_prompt_cnt += 1
+                if new_prompt_cnt % max(1, num_requests // 10) == 0:
+                    print(f"[{new_prompt_cnt}/{num_requests}] packed prompts created")
+                new_prompt = ""
+                new_prompt_len = 0
             else:
-                mm_content = None
-            if enable_multimodal_chat:
-                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
-            samples.append(
-                SampleRequest(
-                    prompt=prompt,
-                    prompt_len=prompt_len,
-                    expected_output_len=new_output_len,
-                    lora_request=lora_request,
-                    multi_modal_data=mm_content,
-                    request_id=request_id_prefix + str(ind),
-                )
-            )
+                if not is_valid_sequence(
+                    prompt_len, new_output_len,
+                    skip_min_output_len_check=output_len is not None,
+                ):
+                    continue
+                if image_path := entry.get("image"):
+                    mm_content = process_image(image_path)
+                elif video_path := entry.get("video"):
+                    mm_content = process_video(video_path)
+                else:
+                    mm_content = None
+                if enable_multimodal_chat:
+                    prompt = self.apply_multimodal_chat_transformation(
+                        prompt, mm_content)
+                samples.append(
+                    SampleRequest(
+                        prompt=prompt,
+                        prompt_len=prompt_len,
+                        expected_output_len=new_output_len,
+                        lora_request=lora_request,
+                        multi_modal_data=mm_content,
+                        request_id=request_id_prefix + str(ind),
+                    ))
             ind += 1
         self.maybe_oversample_requests(
             samples, num_requests, request_id_prefix, no_oversample
@@ -1409,6 +1453,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
     )
 
     sharegpt_group = parser.add_argument_group("sharegpt dataset options")
+    sharegpt_group.add_argument(
+        "--sharegpt-input-len",
+        type=int,
+        default=None,
+        help="Fixed input length in tokens for each request; ShareGPT "
+        "prompts are packed and truncated to reach exactly this length.",
+    )
     sharegpt_group.add_argument(
         "--sharegpt-output-len",
         type=int,
@@ -1824,6 +1875,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         ).sample(
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
+            input_len=args.sharegpt_input_len,
            output_len=args.sharegpt_output_len,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
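
Reviewer note: the heart of the change is the greedy packing loop in
sample(). The sketch below shows the same technique in isolation, assuming a
HuggingFace-style tokenizer; pack_prompts is an illustrative helper, not
part of the patch.

    from transformers import AutoTokenizer


    def pack_prompts(prompts: list[str], tokenizer, input_len: int) -> list[str]:
        """Greedily concatenate prompts until roughly input_len tokens,
        truncating the final piece at the token level to fit."""
        packed: list[str] = []
        cur, cur_len = "", 0
        for prompt in prompts:
            ids = tokenizer(prompt).input_ids
            if len(ids) > input_len:
                continue  # a single prompt longer than the budget is skipped
            remaining = input_len - cur_len
            if remaining > len(ids):
                # Still room: append the whole prompt and keep packing.
                cur = f"{cur} {prompt}" if cur else prompt
            else:
                # Truncate the prompt's token ids to fill the remaining budget.
                cur = cur + tokenizer.decode(ids[:remaining])
            cur_len = len(tokenizer(cur).input_ids)
            if cur_len >= input_len:
                packed.append(cur)
                cur, cur_len = "", 0
        return packed


    if __name__ == "__main__":
        tok = AutoTokenizer.from_pretrained("gpt2")
        prompts = ["Tell me about LLM serving.", "What is paged attention?"] * 50
        for p in pack_prompts(prompts, tok, input_len=64):
            print(len(tok(p).input_ids))

Note the re-tokenization after every concatenation: decoding a token-id
prefix and re-encoding the joined string does not always round-trip to
exactly input_len tokens, since merges can differ at the seam. That is why
the patch re-validates each packed prompt with is_valid_sequence(...,
min_len=input_len, max_prompt_len=input_len) rather than trusting the
arithmetic.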
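Usage sketch, assuming the ShareGPTDataset constructor and the sample()
signature in the current datasets.py (the tokenizer name and dataset path
are placeholders); from the CLI, the equivalent is passing
--sharegpt-input-len alongside --sharegpt-output-len:

    from transformers import AutoTokenizer
    from vllm.benchmarks.datasets import ShareGPTDataset

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    dataset = ShareGPTDataset(
        dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json")
    requests = dataset.sample(
        tokenizer=tokenizer,
        num_requests=100,
        input_len=512,   # new: pack/truncate every prompt to 512 tokens
        output_len=128,  # same as --sharegpt-output-len
    )
    for r in requests[:3]:
        print(r.prompt_len, r.expected_output_len)

One design consequence worth flagging: packing consumes several dataset
entries per emitted request, so the sampler may read well past num_requests
entries from the JSON file before it collects enough packed prompts.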