[Feature] [KVCache] support attention_store kv cache backend #5823
Changes from all commits (20 commits): bb53e2c, 8261f6d, cb2d51d, 92a0391, a4c9904, 2b64879, ad841d1, 4af8c66, fda4597, f11e4ad, 34bdd42, 030ce32, 48617d1, 36f889b, 3d6bd87, 9625602, ef14c9c, 117d28a, 9df86a7, 2b48659
fastdeploy/cache_manager/cache_tasks.py (new file)
@@ -0,0 +1,37 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from dataclasses import dataclass
from typing import List


@dataclass(frozen=True, kw_only=True)
class CacheTask:
    task_id: str
    keys: List[str]
    token_ids: List[int]
    gpu_block_ids: List[int]
@dataclass(frozen=True, kw_only=True)
class ReadStorageTask(CacheTask):
    start_read_block_idx: int
    timeout: float = 30.0


@dataclass(frozen=True, kw_only=True)
class WriteStorageTask(CacheTask):
    timeout: float = 30.0
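For readers unfamiliar with frozen, keyword-only dataclasses, here is a minimal usage sketch of the task types defined above (Python 3.10+ is required for `kw_only=True`). All concrete key strings, token ids, and block ids below are made-up illustration values, not values produced by FastDeploy.

```python
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask

# All values below are placeholders for illustration only.
read_task = ReadStorageTask(
    task_id="req-0",                       # request id used to key the prefetch event
    keys=["blockhash-a", "blockhash-b"],   # one storage key per KV-cache block
    token_ids=[101, 102, 103, 104],        # token ids covered by those blocks
    gpu_block_ids=[7, 9],                  # GPU blocks that receive the prefetched data
    start_read_block_idx=1,                # first block index to read from storage
)  # timeout keeps its default of 30.0 seconds

write_task = WriteStorageTask(
    task_id="req-0",
    keys=["blockhash-a", "blockhash-b"],
    token_ids=[101, 102, 103, 104],
    gpu_block_ids=[7, 9],
)

# frozen=True makes a task immutable once issued; for example,
# `read_task.timeout = 5.0` would raise dataclasses.FrozenInstanceError.
```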
fastdeploy/cache_manager/prefix_cache_manager.py
@@ -31,7 +31,9 @@
 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
 from fastdeploy.cache_manager.cache_metrics import CacheMetrics
+from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
 from fastdeploy.cache_manager.ops import get_all_visible_devices
+from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
 from fastdeploy.metrics.metrics import main_process_metrics

@@ -47,7 +49,7 @@ class PrefixCacheManager:

     def __init__(
         self,
-        config,
+        config: FDConfig,
         tensor_parallel_size,
         splitwise_role="mixed",
         local_data_parallel_id=0,

@@ -207,7 +209,6 @@ def launch_cache_manager(
         key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
         key_cache_shape = ",".join([str(i) for i in key_cache_shape])
         val_cache_shape = ",".join([str(i) for i in val_cache_shape])
-        logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}")
         if self.enable_splitwise:
             cache_messager_processes = self.launch_cache_messager(
                 cache_config,

@@ -273,6 +274,7 @@ def launch_cache_manager(
                 + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
                 + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
                 + f" {sys.executable} {py_path}"
+                + f" --model_id {os.path.basename(self.config.model_config.model)}"
                 + f" --device_id {int(device_ids[i])}"
                 + f" --rank {i}"
                 + f" --splitwise_role {self.splitwise_role}"

@@ -390,7 +392,7 @@ def launch_cache_messager(
                 + f" --ipc_suffix {ipc_suffix}"
                 + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
                 + f" --speculative_config '{self.speculative_config.to_json_string()}'"
-                + f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
+                + f" >{log_dir}/launch_cache_messager_{i}.log 2>&1"
             )
             logger.info(f"Launch cache messager, command:{launch_cmd}")
             cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))

@@ -789,9 +791,15 @@ def request_match_blocks(self, task: Request, block_size, *args):
                     f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
                 )
                 start_time = time.time()
-                storage_matched_block_ids = self.issue_prefetch_storage_task(
-                    req_id, no_match_block_keys, gpu_recv_storage_block_ids
+                read_storage_task = ReadStorageTask(
+                    task_id=req_id,
+                    keys=no_match_block_keys,
+                    token_ids=input_token_ids,
Collaborator: If long prompts carry token_ids here, the cross-process communication can get fairly heavy. Consider gating this — mooncake does not pass token_ids, and in transfer_manager.py only the attention_store backend needs token_ids.

Author: Sure, that works.
+                    gpu_block_ids=gpu_recv_storage_block_ids,
+                    start_read_block_idx=match_token_num // block_size,
                 )
+                logger.debug(f"issue read storage task: {read_storage_task}")
+                storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
                 storage_matched_block_num = len(storage_matched_block_ids)
                 storage_match_token_num = storage_matched_block_num * block_size
                 cost_time = time.time() - start_time

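As a concrete illustration of the review suggestion above, the sketch below only ships token_ids for backends that consume them, so that long prompts do not inflate the cross-process payload for backends such as mooncake. The helper name and the backend-name strings are hypothetical illustration choices, not the PR's actual API.

```python
from typing import List


def token_ids_for_backend(backend_name: str, token_ids: List[int]) -> List[int]:
    """Return token_ids only for backends that need them (hypothetical helper).

    Per the review comment, attention_store is the only backend that consumes
    token_ids in transfer_manager.py, so other backends (e.g. mooncake) can
    send an empty list and keep the transfer-task payload small.
    """
    return token_ids if backend_name == "attention_store" else []


# A mooncake-style backend then ships no token ids, even for a very long prompt.
print(token_ids_for_backend("mooncake", list(range(100_000))))  # -> []
```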
@@ -1006,6 +1014,12 @@ def write_cache_to_storage(self, request: Request):
         if self.kvcache_storage_backend is None:
             return

+        token_ids = request.prompt_token_ids
+        if isinstance(token_ids, np.ndarray):
+            token_ids = token_ids.tolist()
+        if self.config.cache_config.enable_output_caching:
+            token_ids += request.output_token_ids
+
         req_id = request.request_id
         keys = []
         node = self.req_leaf_map[req_id]

@@ -1018,24 +1032,33 @@ def write_cache_to_storage(self, request: Request):
         gpu_block_ids = request.block_tables[: len(keys)]
         logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
+        write_storage_task = WriteStorageTask(
+            task_id=req_id,
+            keys=keys,
+            token_ids=token_ids,
+            gpu_block_ids=gpu_block_ids,
+        )
+        logger.debug(f"issue write storage task: {write_storage_task}")
         tic = time.time()
-        self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True)
+        self.issue_write_back_storage_task(write_storage_task, is_sync=True)
         cost_time = time.time() - tic
         logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")

-    def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
+    def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True):
         if self.kvcache_storage_backend is None:
             return

-        if len(hash_keys) != len(gpu_block_ids):
-            err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})"
+        if len(task.keys) != len(task.gpu_block_ids):
+            err_msg = (
+                f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})"
+            )
             logger.error(err_msg)
             raise ValueError(err_msg)

-        self.task_write_back_event[req_id] = Event()
-        self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
+        self.task_write_back_event[task.task_id] = Event()
+        self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task))
         if is_sync:
-            self.wait_write_storage_task(req_id)
+            self.wait_write_storage_task(task.task_id)

     def wait_write_storage_task(self, req_id):
         """

@@ -1045,16 +1068,19 @@ def wait_write_storage_task(self, req_id):
         self.task_write_back_event[req_id].wait()
         del self.task_write_back_event[req_id]

-    def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
+    def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True):
         """
         Prefetch cache from storage task
         """
         if self.kvcache_storage_backend is None:
             return []

         storage_block_ids = []
-        self.task_prefetch_event[req_id] = Event()
+        self.task_prefetch_event[task.task_id] = Event()
         # issue task to cache_transfer_manager
-        self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
+        self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task))
         if is_sync:
-            storage_block_ids = self.wait_prefetch_storage_task(req_id)
+            storage_block_ids = self.wait_prefetch_storage_task(task.task_id)
         return storage_block_ids

     def wait_prefetch_storage_task(self, req_id):
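With this change, a transfer task is enqueued as a (CacheStatus, CacheTask) pair rather than a flat tuple of positional fields. The sketch below is a hypothetical illustration of how a consumer such as the cache transfer manager could unpack that pair; it is not the PR's actual worker code and only assumes the imports and enum members shown in the diff above.

```python
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask


# Hypothetical consumer-side dispatch; the real cache transfer manager may differ.
# The queue item is assumed to be the (status, task) pair enqueued by
# put_transfer_task in the diff above.
def handle_transfer_task(item):
    status, task = item
    if status == CacheStatus.STORAGE2GPU and isinstance(task, ReadStorageTask):
        # Prefetch: copy the blocks named by task.keys from the storage backend
        # into task.gpu_block_ids, starting at task.start_read_block_idx, within
        # task.timeout seconds.
        ...
    elif status == CacheStatus.GPU2STORAGE and isinstance(task, WriteStorageTask):
        # Write-back: persist task.gpu_block_ids to the storage backend under
        # task.keys (token_ids are available for backends that need them).
        ...
```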