From bb53e2c0482a427c649e570e61f1a2113c29efa0 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Mon, 29 Dec 2025 12:19:36 +0000 Subject: [PATCH 01/16] [feat] support attention_store kv cache backend --- fastdeploy/cache_manager/cache_messager.py | 5 +- fastdeploy/cache_manager/cache_tasks.py | 19 + .../cache_manager/cache_transfer_manager.py | 398 ++++++++++-------- .../cache_manager/prefix_cache_manager.py | 42 +- .../transfer_factory/__init__.py | 3 +- .../transfer_factory/kvcache_storage.py | 7 + .../mooncake_store/__init__.py | 3 +- .../mooncake_store/attention_store.py | 191 +++++++++ .../mooncake_store/mooncake_store.py | 15 + fastdeploy/engine/args_utils.py | 3 +- .../engine/sched/resource_manager_v1.py | 11 +- 11 files changed, 499 insertions(+), 198 deletions(-) create mode 100644 fastdeploy/cache_manager/cache_tasks.py create mode 100644 fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index dd56b195bdc..de877f61e4c 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -1029,7 +1029,10 @@ def main(): args = parse_args() rank_id = args.rank + args.local_data_parallel_id * args.mp_num - logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log") + if args.mp_num > 1: + logger = get_logger("cache_messager", f"cache_messager.log.{rank_id}") + else: + logger = get_logger("cache_messager", f"cache_messager.log") logger.info("create cache messager...") logger.info(f"{args}") diff --git a/fastdeploy/cache_manager/cache_tasks.py b/fastdeploy/cache_manager/cache_tasks.py new file mode 100644 index 00000000000..9d281a98f4b --- /dev/null +++ b/fastdeploy/cache_manager/cache_tasks.py @@ -0,0 +1,19 @@ + +from dataclasses import dataclass +from typing import List, Optional + +@dataclass(frozen=True, kw_only=True) +class CacheTask: + task_id: str + keys: List[str] + token_ids: List[int] + gpu_block_ids: List[int] + +@dataclass(frozen=True, kw_only=True) +class ReadStorageTask(CacheTask): + start_read_block_idx: int + timeout: float = 30.0 + +@dataclass(frozen=True, kw_only=True) +class WriteStorageTask(CacheTask): + timeout: float = 30.0 \ No newline at end of file diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index faf9255441c..7135484f113 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -23,11 +23,13 @@ import time import traceback from typing import List +from math import prod import numpy as np import paddle from fastdeploy import envs +from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.cache_manager.ops import ( cuda_host_alloc, @@ -40,7 +42,7 @@ swap_cache_layout, unset_data_ipc, ) -from fastdeploy.cache_manager.transfer_factory import MooncakeStore +from fastdeploy.cache_manager.transfer_factory import MooncakeStore, AttentionStore from fastdeploy.config import SpeculativeConfig from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.platforms import current_platform @@ -109,7 +111,7 @@ def parse_args(): "--kvcache_storage_backend", type=str, default=None, - choices=["mooncake", "none"], + choices=["mooncake", "attention_store", "none"], help="The storage backend for kvcache storage. 
If not set, storage backend is disabled.", ) parser.add_argument( @@ -133,8 +135,6 @@ def __init__(self, args): """ 初始化CacheTransferManager """ - device = args.device_id - rank = args.rank self.gpu_cache_kvs = {} self.cpu_cache_kvs = {} self.gpu_cache_k_tensors = [] @@ -142,11 +142,27 @@ def __init__(self, args): self.gpu_cache_scales_k_tensors = [] self.gpu_cache_scales_v_tensors = [] self.speculative_config = SpeculativeConfig(args.speculative_config) + + # parse kv cache shape self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")] self.value_cache_shape = [] if args.value_cache_shape: self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")] + + # extract kv cache shape into fields self.num_gpu_blocks = self.key_cache_shape[0] + self.head_num = self.key_cache_shape[1] + self.block_size = self.key_cache_shape[2] + self.head_dim = self.key_cache_shape[3] + + # extract other arg values + self.n_ranks = args.mp_num + self.rank = args.rank + self.device = args.device_id + self.num_layers = args.num_layers + self.ipc_suffix = args.ipc_suffix + self.cache_dtype = args.cache_dtype + self.local_data_parallel_id = args.local_data_parallel_id self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) paddle.set_default_dtype(args.default_dtype) @@ -158,18 +174,13 @@ def __init__(self, args): self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) self.transfer_task_queue = queue.Queue() # 用来接收传输任务 self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕 - self.n_ranks = args.mp_num - self.rank = rank - self.device = device - self.ipc_suffix = args.ipc_suffix - self.cache_dtype = args.cache_dtype address = (args.pod_ip, args.cache_queue_port) self.cache_task_queue = EngineCacheQueue( address=address, is_server=False, num_client=args.mp_num, - client_id=rank, + client_id=self.rank, local_data_parallel_id=args.local_data_parallel_id, ) @@ -223,8 +234,21 @@ def __init__(self, args): self.storage_backend = MooncakeStore(tp_rank=self.rank) self._init_storage_buffer(args) logger.info("Initialized mooncake store successfully") + elif args.kvcache_storage_backend == "attention_store": + logger.info("Start initialize attention store...") + self.storage_backend = AttentionStore( + shard_id=self.rank, + shard_num=self.n_ranks, + layer_num=self.num_layers + self.num_extra_layers, + block_token_size=self.key_cache_shape[2], # key_cache_shape: [num_blocks, head_num, block_size, head_dim] + bytes_per_shard_layer_per_block=prod(self.key_cache_shape[1:]), + device_id=self.device, + dp_id=self.local_data_parallel_id, + ) + logger.info("Initialized attention store successfully!") else: raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}") + self.storage_backend_type = args.kvcache_storage_backend if args.write_policy not in ["write_through"]: raise ValueError(f"Invalid write policy: {args.write_policy}") @@ -238,7 +262,7 @@ def _init_storage_buffer(self, args): cache layout: layer_num * [block_num, head_num, block_size, head_dim] buffer layout: [block_num, layer_num, head_num, block_size, head_dim] """ - layer_num = args.num_layers + self.num_extra_layers + layer_num = self.num_layers + self.num_extra_layers head_num = self.key_cache_shape[1] block_size = self.key_cache_shape[2] head_dim = self.key_cache_shape[3] @@ -288,8 +312,8 @@ def _init_gpu_cache(self, args): logger.info(f"[rank 
{self.rank}/{self.n_ranks}] Initializing kv cache for all layers.") set_device(self.device) - for i in range(args.num_layers + self.num_extra_layers): - num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks + for i in range(self.num_layers + self.num_extra_layers): + num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}" val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}" key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}" @@ -407,7 +431,7 @@ def _init_cpu_cache(self, args): self.v_dst_ptrs = [] self.k_scales_ptrs = [] self.v_scales_ptrs = [] - for i in range(args.num_layers + self.num_extra_layers): + for i in range(self.num_layers + self.num_extra_layers): key_name = f"key_caches_{i}_rank{self.rank}" val_name = f"value_caches_{i}_rank{self.rank}" key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}" @@ -438,211 +462,235 @@ def _get_cache_bytes(self, cache_dtype): raise ValueError(f"Unsupported cache dtype: {cache_dtype}") return cache_bytes - def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]): + def _run_read_storage( + self, + token_ids: List[int], + start_read_block_idx: int, + k_cache_keys: List[str], + v_cache_keys: List[str], + gpu_block_ids: List[int], + cpu_block_ids: List[int], + timeout: float, + ): """ - Given the k_keys and v_keys, get the valid blocks number that - can be prefetched from storage backend. + Read storage data from the given blocks to the corresponding cache tensors on the current rank's GPU. """ - assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length." - result = self.storage_backend.exists(k_keys + v_keys) - - # only consider the case when both key and value exist - num = 0 - for k, v in zip(k_keys, v_keys): - if result[k] and result[v]: - num += 1 - return num - - def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids): try: - logger.debug( - f"_run_read_storage, key_hash_keys: {k_cache_keys}, " - f"value_hash_keys: {v_cache_keys}, gpu_block_ids: {gpu_block_ids}" - ) + if self.storage_backend_type == "mooncake": + k_cache_keys = k_cache_keys[:start_read_block_idx] + v_cache_keys = v_cache_keys[:start_read_block_idx] + gpu_block_ids = gpu_block_ids[:start_read_block_idx] + cpu_block_ids = cpu_block_ids[:start_read_block_idx] + + block_num = len(gpu_block_ids) + keys = k_cache_keys + v_cache_keys + k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids] + v_cache_ptrs = [self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids] + kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs + kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) # TODO(liyonghua): impl for attention store + + k_result, v_result = result[:block_num], result[block_num:] + success_block_num = 0 + for k, v in zip(k_result, v_result): + if k > 0 and v > 0: + success_block_num += 1 + logger.debug(f"_run_read_storage, success_block_num: {success_block_num}") + valid_gpu_block_ids = gpu_block_ids[:success_block_num] + valid_cpu_block_ids = cpu_block_ids[:success_block_num] + + mode = 1 # cpu ==> gpu + swap_cache_layout( + self.gpu_cache_k_tensors, + self.storage_key_read_buffer, + 
self.key_cache_shape, + valid_gpu_block_ids, + valid_cpu_block_ids, + self.device, + mode, + ) + swap_cache_layout( + self.gpu_cache_v_tensors, + self.storage_value_read_buffer, + self.value_cache_shape, + valid_gpu_block_ids, + valid_cpu_block_ids, + self.device, + mode, + ) - block_num = len(gpu_block_ids) - keys = k_cache_keys + v_cache_keys - k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids] - v_cache_ptrs = [ - self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids - ] - kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs - kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value - result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) - - k_result, v_result = result[:block_num], result[block_num:] - success_block_num = 0 - for k, v in zip(k_result, v_result): - if k > 0 and v > 0: - success_block_num += 1 - logger.debug(f"_run_read_storage, success_block_num: {success_block_num}") - valid_gpu_block_ids = gpu_block_ids[:success_block_num] - valid_cpu_block_ids = cpu_block_ids[:success_block_num] - - mode = 1 # cpu ==> gpu - swap_cache_layout( - self.gpu_cache_k_tensors, - self.storage_key_read_buffer, - self.key_cache_shape, - valid_gpu_block_ids, - valid_cpu_block_ids, - self.device, - mode, - ) - swap_cache_layout( - self.gpu_cache_v_tensors, - self.storage_value_read_buffer, - self.value_cache_shape, - valid_gpu_block_ids, - valid_cpu_block_ids, - self.device, - mode, - ) + elif self.storage_backend_type == "attention_store": + key_cache = [] + val_cache = [] + for i in range(self.num_layers + self.num_extra_layers): + key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) + val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) + read_block_num = self.storage_backend.read(key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout) + valid_gpu_block_ids = gpu_block_ids[:read_block_num] return valid_gpu_block_ids + except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: " - f"error:{e}, {traceback.format_exc()}" + f"An error occurred in _run_read_storage, " + f"error: {e}, traceback:\n{traceback.format_exc()}" ) raise - def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1): + def read_storage_task(self, task: ReadStorageTask): """Read cache from the storage backend to the GPU memory.""" try: - logger.debug( - f"read_storage_task, task id: {task_id}, hash_keys: {keys}, " - f"gpu_block_ids: {gpu_block_ids}, timeout: {timeout}" - ) - k_cache_keys = [f"{key}_key_{self.rank}" for key in keys] - v_cache_keys = [f"{key}_value_{self.rank}" for key in keys] - match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys) - logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}") - - k_cache_keys = k_cache_keys[:match_block_num] - v_cache_keys = v_cache_keys[:match_block_num] - gpu_block_ids = gpu_block_ids[:match_block_num] - cpu_block_ids = [i for i in range(match_block_num)] - valid_gpu_block_ids = [] + gpu_block_ids = task.gpu_block_ids.copy() + cpu_block_ids = [i for i in range(len(gpu_block_ids))] + k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys] + v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys] + match_block_num = 0 + if self.storage_backend_type == "mooncake": + 
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys) + elif self.storage_backend_type == "attention_store": + match_block_num = self.storage_backend.query(task.token_ids, task.start_read_block_idx, task.timeout) + logger.info(f"Before reading cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}") + + valid_gpu_block_ids = [] if match_block_num > 0: # TODO: support timeout with actual block count try: valid_gpu_block_ids = self._run_read_storage( - k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids + task.token_ids, + match_block_num, + k_cache_keys, + v_cache_keys, + gpu_block_ids, + cpu_block_ids, + task.timeout ) logger.info( - f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}." + f"Successfully read {match_block_num} blocks from storage for task {task.task_id}." ) except Exception as e: - logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}") + logger.error(f"Failed to read cache for task {task.task_id}, error: {e}") valid_gpu_block_ids = [] - result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids) + result = (CacheStatus.STORAGE2GPU, task.task_id, task.keys, valid_gpu_block_ids) self.cache_task_queue.swap_storage_to_gpu_barrier.wait() self.cache_task_queue.swap_storage_to_gpu_barrier.reset() self.cache_task_queue.put_transfer_done_signal(result) logger.debug(f"read_storage_task: put_transfer_done_signal {result}") - logger.info( - f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, " - f"valid block num {len(valid_gpu_block_ids)}" - ) + except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: " - f"task_id: {task_id}, error:{e}, {traceback.format_exc()}" + f"An error occurred in read_storage_task: " + f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}" ) - def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids): + def _run_write_back_storage(self, token_ids, start_write_block_idx, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids, timeout): try: - logger.debug( - f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, " - f"gpu_block_ids: {gpu_block_ids}" - ) - key_cache_size = [ - self.key_cache_shape[0], - self.key_cache_shape[1], - self.key_cache_shape[2], - self.key_cache_shape[3], - ] - mode = 0 # gpu ==> cpu - swap_cache_layout( - self.gpu_cache_k_tensors, - self.storage_key_write_buffer, - key_cache_size, - gpu_block_ids, - cpu_block_ids, - self.device, - mode, - ) - swap_cache_layout( - self.gpu_cache_v_tensors, - self.storage_value_write_buffer, - key_cache_size, - gpu_block_ids, - cpu_block_ids, - self.device, - mode, - ) + if self.storage_backend_type == "mooncake": + k_cache_keys = k_cache_keys[start_write_block_idx:] + v_cache_keys = v_cache_keys[start_write_block_idx:] + gpu_block_ids = gpu_block_ids[start_write_block_idx:] + cpu_block_ids = cpu_block_ids[start_write_block_idx:] + + key_cache_size = [ + self.key_cache_shape[0], + self.key_cache_shape[1], + self.key_cache_shape[2], + self.key_cache_shape[3], + ] + mode = 0 # gpu ==> cpu + swap_cache_layout( + self.gpu_cache_k_tensors, + self.storage_key_write_buffer, + key_cache_size, + gpu_block_ids, + cpu_block_ids, + self.device, + mode, + ) + swap_cache_layout( + self.gpu_cache_v_tensors, + self.storage_value_write_buffer, + key_cache_size, + gpu_block_ids, + cpu_block_ids, + self.device, + 
mode, + ) + + block_num = len(gpu_block_ids) + keys = k_cache_keys + v_cache_keys + k_cache_ptrs = [ + self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + ] + v_cache_ptrs = [ + self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + ] + kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs + kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) + + elif self.storage_backend_type == "attention_store": + key_cache = [] + val_cache = [] + for i in range(self.num_layers + self.num_extra_layers): + key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) + val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) + self.storage_backend.write(key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout) - block_num = len(gpu_block_ids) - keys = k_cache_keys + v_cache_keys - k_cache_ptrs = [ - self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids - ] - v_cache_ptrs = [ - self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids - ] - kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs - kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value - self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: " - f"error:{e}, {traceback.format_exc()}" + f"An error occurred in _run_write_back_storage, " + f"error: {e}, traceback:\n{traceback.format_exc()}" ) - def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1): + def write_back_storage_task(self, task: WriteStorageTask): """ Write cache to the storage backend from the GPU memory. 
""" try: - logger.debug( - f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, " - f"task_id: {task_id}, timeout: {timeout}" - ) - - k_cache_keys = [f"{key}_key_{self.rank}" for key in keys] - v_cache_keys = [f"{key}_value_{self.rank}" for key in keys] - match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys) - - k_cache_keys = k_cache_keys[match_block_num:] - v_cache_keys = v_cache_keys[match_block_num:] - gpu_block_ids = gpu_block_ids[match_block_num:] + gpu_block_ids = task.gpu_block_ids.copy() cpu_block_ids = [i for i in range(len(gpu_block_ids))] - - if len(k_cache_keys) == 0: - logger.info(f"No uncached keys found for task {task_id}") + k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys] + v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys] + + match_block_num = 0 + if self.storage_backend_type == "mooncake": + match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout) + elif self.storage_backend_type == "attention_store": + match_block_num = self.storage_backend.query(task.token_ids, 0, task.timeout) + logger.info(f"Before writing cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}") + + if match_block_num >= len(k_cache_keys): + logger.info(f"No uncached keys found for task {task.task_id}") gpu_block_ids = [] else: try: # TODO: support timeout with actual block count - self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids) + self._run_write_back_storage( + task.token_ids, + match_block_num, + k_cache_keys, + v_cache_keys, + gpu_block_ids, + cpu_block_ids, + task.timeout + ) + logger.info(f"Successfully wrote cache to storage for task {task.task_id}") except Exception as e: logger.error(f"Error in write back storage task: {e}") gpu_block_ids = [] - result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids) + result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids) self.cache_task_queue.swap_to_storage_barrier.wait() if self.rank == 0: # 只有当rank为0时执行同步操作 self.cache_task_queue.swap_to_storage_barrier.reset() self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号 logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") - logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}") except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: " - f"error:{e}, {traceback.format_exc()}" + f"An error occurred in write_back_storage_task, " + f"error: {e}, traceback:\n{traceback.format_exc()}" ) def _do_swap_to_cpu_task( @@ -734,12 +782,12 @@ def do_data_transfer(self): self.cache_task_queue.barrier1.reset() if self.cache_task_broadcast_signal.value[0] == 1: data, read_finish = self.cache_task_queue.get_transfer_task() - logger.debug(f"transfer data: get_transfer_task {data}") + logger.debug(f"do_data_transfer: {data}") if read_finish: self.cache_task_broadcast_signal.value[0] = 0 - event_type, transfer_task_id = data[0], data[1] + event_type, event_args = data[0], data[1:] if event_type.value == CacheStatus.SWAP2CPU.value: - swap_node_ids, gpu_block_id, cpu_block_id = data[2:] + transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args self.swap_to_cpu_thread_pool.submit( self._do_swap_to_cpu_task, swap_node_ids, @@ -749,7 +797,7 @@ def do_data_transfer(self): transfer_task_id, ) elif event_type.value == CacheStatus.SWAP2GPU.value: - swap_node_ids, 
gpu_block_id, cpu_block_id = data[2:] + transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args self.swap_to_gpu_thread_pool.submit( self._do_swap_to_gpu_task, swap_node_ids, @@ -759,22 +807,16 @@ def do_data_transfer(self): transfer_task_id, ) elif event_type.value == CacheStatus.STORAGE2GPU.value: - hash_keys, gpu_block_ids, timeout = data[2:] + read_storage_task = event_args[0] self.read_storage_thread_pool.submit( self.read_storage_task, - transfer_task_id, - hash_keys, - gpu_block_ids, - timeout, + read_storage_task, ) elif event_type.value == CacheStatus.GPU2STORAGE.value: - hash_keys, gpu_block_ids, timeout = data[2:] + write_storage_task = event_args[0] self.write_back_storage_thread_pool.submit( self.write_back_storage_task, - transfer_task_id, - hash_keys, - gpu_block_ids, - timeout, + write_storage_task, ) else: if self.n_ranks > 1: @@ -1038,7 +1080,11 @@ def main(): args = parse_args() rank_id = args.rank + args.local_data_parallel_id * args.mp_num - logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log") + if args.mp_num > 1: + logger = get_logger("cache_transfer", f"cache_transfer.log.{rank_id}") + else: + logger = get_logger("cache_transfer", f"cache_transfer.log") + logger.info(f"args: {vars(args)}") set_device(args.device_id) try: diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 5a13c9703f1..47fa625a0e4 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -31,6 +31,7 @@ from fastdeploy import envs from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus from fastdeploy.cache_manager.cache_metrics import CacheMetrics +from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.ops import get_all_visible_devices from fastdeploy.engine.request import Request from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus @@ -207,7 +208,6 @@ def launch_cache_manager( key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num) key_cache_shape = ",".join([str(i) for i in key_cache_shape]) val_cache_shape = ",".join([str(i) for i in val_cache_shape]) - logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}") if self.enable_splitwise: cache_messager_processes = self.launch_cache_messager( cache_config, @@ -806,7 +806,14 @@ def request_match_storage_blocks(self, request, extra_gpu_block_ids): prefix_block_key = [cur_block_key] logger.info(f"start prefetch cache from storage, req_id: {req_id}, block num: {len(block_keys)}") - matched_block_ids = self.issue_prefetch_storage_task(req_id, block_keys, extra_gpu_block_ids) + task = ReadStorageTask( + task_id=req_id, + keys=block_keys, + token_ids=input_ids, + gpu_block_ids=extra_gpu_block_ids, + start_read_block_idx=num_cached_tokens // block_size, + ) + matched_block_ids = self.issue_prefetch_storage_task(task, is_sync=True) logger.info( f"finish prefetch cache from storage, req_id: {req_id}, matched block num: {len(matched_block_ids)}" ) @@ -990,24 +997,30 @@ def write_cache_to_storage(self, request: Request): gpu_block_ids = request.block_tables[: len(keys)] logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}") + task = WriteStorageTask( + task_id=req_id, + keys=keys, + token_ids=request.prompt_token_ids + request.output_token_ids, + 
gpu_block_ids=gpu_block_ids, + ) tic = time.time() - self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True) + self.issue_write_back_storage_task(task, is_sync=True) cost_time = time.time() - tic logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") - def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5): + def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): if self.kvcache_storage_backend is None: return - if len(hash_keys) != len(gpu_block_ids): - err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})" + if len(task.keys) != len(task.gpu_block_ids): + err_msg = f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})" logger.error(err_msg) raise ValueError(err_msg) - self.task_write_back_event[req_id] = Event() - self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout)) + self.task_write_back_event[task.task_id] = Event() + self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task)) if is_sync: - self.wait_write_storage_task(req_id) + self.wait_write_storage_task(task.task_id) def wait_write_storage_task(self, req_id): """ @@ -1017,16 +1030,19 @@ def wait_write_storage_task(self, req_id): self.task_write_back_event[req_id].wait() del self.task_write_back_event[req_id] - def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5): + def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True): """ Prefetch cache from storage task """ + if self.kvcache_storage_backend is None: + return [] + storage_block_ids = [] - self.task_prefetch_event[req_id] = Event() + self.task_prefetch_event[task.task_id] = Event() # issue task to cache_transfer_manager - self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout)) + self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task)) if is_sync: - storage_block_ids = self.wait_prefetch_storage_task(req_id) + storage_block_ids = self.wait_prefetch_storage_task(task.task_id) return storage_block_ids def wait_prefetch_storage_task(self, req_id): diff --git a/fastdeploy/cache_manager/transfer_factory/__init__.py b/fastdeploy/cache_manager/transfer_factory/__init__.py index f9c7d5dc979..5bcf74671d5 100644 --- a/fastdeploy/cache_manager/transfer_factory/__init__.py +++ b/fastdeploy/cache_manager/transfer_factory/__init__.py @@ -17,7 +17,7 @@ from fastdeploy.platforms import current_platform from .kvcache_storage import KVCacheStorage -from .mooncake_store import MooncakeStore +from .mooncake_store import MooncakeStore, AttentionStore from .rdma_cache_transfer import RDMACommManager if current_platform.is_cuda(): @@ -31,4 +31,5 @@ "RDMACommManager", "KVCacheStorage", "MooncakeStore", + "AttentionStore", ] diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py b/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py index 2345c061aaf..2428babf4de 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py @@ -95,3 +95,10 @@ def clear(self) -> bool: Clear all keys in storage """ pass + + @abstractmethod + def query(self) -> int: + """ + Query the number of blocks stored in the storage. 
+ """ + pass diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py index 1de4084be6b..643bc686a86 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py @@ -15,5 +15,6 @@ """ from .mooncake_store import MooncakeStore +from .attention_store import AttentionStore -__all__ = ["MooncakeStore"] +__all__ = ["MooncakeStore", "AttentionStore"] diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py new file mode 100644 index 00000000000..d7bb3bbd48b --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -0,0 +1,191 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import ctypes +import paddle +import json +import os +import time +import traceback +import uuid +from dataclasses import dataclass, fields +from typing import Any, List, Optional + +import attentionstore_sdk.api.common.common_pb2 as common_pb2 +from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens +from attentionstore_sdk.utils.err import AttentionStoreSDKError +from fastdeploy.cache_manager.transfer_factory.kvcache_storage import ( + KVCacheStorage, + logger, +) + +@dataclass +class AttentionStoreConfig: + namespace: str = "default_ns" + pod_name: str = "default_pod" + model_version: str = "v0" + shard_id: int = 0 + shard_num: int = 1 + layer_num: int = 1 + block_token_size: int = 64 + bytes_per_shard_layer_per_block: int = 1024 + device_id: int = 0 + dp_id: int = 0 + +class AttentionStore(KVCacheStorage): + def __init__(self, **args): + + try: + import attentionstore_sdk.api.common.common_pb2 as common_pb2 + from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens + except ImportError as e: + raise ImportError( + "Please install attentionstore_sdk to run Fastdeploy with attentionstore_sdk." 
+ ) from e + + self.config = AttentionStoreConfig(**args) + + try: + logger.info(f"Start initializing AttentionStoreSDK with config: {self.config}") + self.sdk = AttentionStoreSDK( + self.config.namespace, + self.config.pod_name, + self.config.model_version, + self.config.shard_id, + self.config.shard_num, + self.config.layer_num, + self.config.block_token_size, + self.config.bytes_per_shard_layer_per_block, + self.config.device_id, + self.config.dp_id, + ) + self.wait_for_sdk_ready(timeout=300, delta_t=5) + logger.info(f"✅ AttentionStoreSDK is inititialized successfully!") + except Exception as e: + logger.error(f"❌ AttentionStoreSDK initialization failed, error: {e}, traceback: {traceback.format_exc()}" + f"\nconfig: {self.config}") + + def wait_for_sdk_ready(self, timeout: float, delta_t: float): + t = 0 + while t < timeout: + try: + tokens = Tokens(list(range(self.config.block_token_size + 1)), self.config.block_token_size) + self.sdk.match(tokens, 0, delta_t) + return + except AttentionStoreSDKError as e: + if "cuda memory not ready" in str(e): + logger.debug(f"wait_for_sdk_ready: cuda memory not ready, try again..") + time.sleep(delta_t) + continue + else: + raise RuntimeError(f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}") + finally: + t += delta_t + raise TimeoutError(f"AttentionStoreSDK initialization timed out after {timeout} seconds") + + def read( + self, + key_cache: List[paddle.Tensor], + val_cache: List[paddle.Tensor], + token_ids: List[int], + gpu_block_ids: List[int], + start_read_block_idx: int, + timeout: float = 30.0, + ): + logger.debug(f"read: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_read_block_idx={start_read_block_idx} timeout={timeout}") + tokens = Tokens(token_ids, self.config.block_token_size) + k_data_ptrs = [k.data_ptr()for k in key_cache] + v_data_ptrs = [v.data_ptr()for v in val_cache] + num = 0 + try: + num = self.sdk.read( + list(range(self.config.layer_num)), + tokens, + start_read_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) + logger.debug(f"read: successfully read {num} blocks") + except AttentionStoreSDKError as e: + logger.error(f"Failed to execute AttentionStoreSDK read, error: {e}, traceback:\n{traceback.format_exc()}") + return num + + def write( + self, + key_cache: List[paddle.Tensor], + val_cache: List[paddle.Tensor], + token_ids: List[int], + gpu_block_ids: List[int], + start_write_block_idx: int, + timeout: float = 30.0, + ) -> int: + logger.debug(f"write: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_write_block_idx={start_write_block_idx} timeout={timeout}") + tokens = Tokens(token_ids, self.config.block_token_size) + k_data_ptrs = [k.data_ptr()for k in key_cache] + v_data_ptrs = [v.data_ptr()for v in val_cache] + num = 0 + try: + num = self.sdk.write( + list(range(self.config.layer_num)), + tokens, + start_write_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) + logger.debug(f"write: successfully wrote {num} blocks") + except AttentionStoreSDKError as e: + logger.error(f"Failed to execute AttentionStoreSDK write, error: {e}, traceback:\n{traceback.format_exc()}") + return num + + def query(self, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0): + """ + Given the input ids and starting index to match, get the valid blocks number that + can be prefetched from storage backend. 
+ """ + logger.debug(f"query: token_ids={token_ids} start_match_block_idx={start_match_block_idx} timeout={timeout}") + tokens = Tokens(token_ids, self.config.block_token_size) + num = 0 + try: + num = self.sdk.match(tokens, start_match_block_idx, timeout) + logger.debug(f"query: successfully matched {num} blocks") + except AttentionStoreSDKError as e: + logger.error(f"Failed to execute AttentionStoreSDK match, error: {e}, traceback:\n{traceback.format_exc()}") + return num + + def get(self, **kwargs): + raise NotImplementedError(f"AttentionStore does not support this method") + + def batch_get(self, **kwargs): + raise NotImplementedError(f"AttentionStore does not support this method") + + def set(self, **kwargs) -> bool: + raise NotImplementedError(f"AttentionStore does not support this method") + + def batch_set(self, **kwargs) -> bool: + raise NotImplementedError(f"AttentionStore does not support this method") + + def exists(self, keys: List[str]) -> bool: + raise NotImplementedError(f"AttentionStore does not support this method") + + def clear(self) -> bool: + raise NotImplementedError(f"AttentionStore does not support this method") + + def register_buffer(self, buffer_ptr, buffer_size, buffer_type="none_type") -> None: + raise NotImplementedError(f"AttentionStore does not support this method") \ No newline at end of file diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py index 3311bc3f256..dbaf9ffc315 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -235,6 +235,21 @@ def exists(self, keys: List[str]): logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms") return result + def query(self, k_keys: List[str], v_keys: List[str], timeout: float = 1.0): + """ + Given the k_keys and v_keys, get the valid blocks number that + can be prefetched from storage backend. + """ + assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length." + result = self.exists(k_keys + v_keys) + + # only consider the case when both key and value exist + num = 0 + for k, v in zip(k_keys, v_keys): + if result[k] and result[v]: + num += 1 + return num + def delete(self, key, timeout=5) -> bool: while timeout: result = self.store.remove(key) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 0318868f897..ca98a44a12a 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -601,7 +601,6 @@ def post_init_ports(name: str, ports: list, num_total_ports: int): for port in ports: assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use." - console_logger.debug(f"post init {name}: {ports}") return ports num_nodes = len(self.ips) if self.ips else 1 @@ -1037,7 +1036,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--kvcache-storage-backend", type=nullable_str, - choices=["mooncake"], + choices=["mooncake", "attention_store"], default=EngineArgs.kvcache_storage_backend, help="The storage backend for kvcache storage. 
Leave empty to disable.", ) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index b6da9a99224..777b559748c 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -961,10 +961,10 @@ def get_storage_cached_blocks(self, request: Request, extra_gpu_block_ids: list try: tic = time.time() req_id = request.request_id - llm_logger.debug(f"get_storage_cached_blocks start process req {req_id}") + llm_logger.debug(f"get_storage_cached_blocks: start process req {req_id}") matched_block_ids = self.cache_manager.request_match_storage_blocks(request, extra_gpu_block_ids) llm_logger.debug( - f"matched {len(matched_block_ids)} blocks from storage for req_id:{req_id}, " + f"get_storage_cached_blocks: matched {len(matched_block_ids)} blocks from storage for req_id: {req_id}, " f"cost_time: {time.time() - tic:.6f}s" ) @@ -1215,8 +1215,11 @@ def clear_data(self): def update_metrics(self): # Update metrics num_tasks = sum([1 if task else 0 for task in self.tasks_list]) - num_blocks_used_by_tasks = sum([len(task.block_tables) if task else 0 for task in self.tasks_list]) - main_process_metrics.available_gpu_block_num.set(self.total_block_number() - num_blocks_used_by_tasks) + blocks_used_by_tasks = set() + for task in self.tasks_list: + if task is not None: + blocks_used_by_tasks.union(task.block_tables) + main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks)) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) main_process_metrics.num_requests_running.set(len(self.running)) From 8261f6d6f3200150d38fe30df1bfbde2a27116c6 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Mon, 29 Dec 2025 12:57:06 +0000 Subject: [PATCH 02/16] [fix] fix codestyle --- fastdeploy/cache_manager/cache_messager.py | 2 +- fastdeploy/cache_manager/cache_tasks.py | 8 +- .../cache_manager/cache_transfer_manager.py | 103 ++++++++++-------- .../cache_manager/prefix_cache_manager.py | 12 +- .../transfer_factory/__init__.py | 2 +- .../mooncake_store/__init__.py | 2 +- .../mooncake_store/attention_store.py | 100 +++++++++-------- 7 files changed, 129 insertions(+), 100 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index de877f61e4c..53c6a883550 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -1032,7 +1032,7 @@ def main(): if args.mp_num > 1: logger = get_logger("cache_messager", f"cache_messager.log.{rank_id}") else: - logger = get_logger("cache_messager", f"cache_messager.log") + logger = get_logger("cache_messager", "cache_messager.log") logger.info("create cache messager...") logger.info(f"{args}") diff --git a/fastdeploy/cache_manager/cache_tasks.py b/fastdeploy/cache_manager/cache_tasks.py index 9d281a98f4b..43ffb6a682f 100644 --- a/fastdeploy/cache_manager/cache_tasks.py +++ b/fastdeploy/cache_manager/cache_tasks.py @@ -1,6 +1,6 @@ - from dataclasses import dataclass -from typing import List, Optional +from typing import List + @dataclass(frozen=True, kw_only=True) class CacheTask: @@ -9,11 +9,13 @@ class CacheTask: token_ids: List[int] gpu_block_ids: List[int] + @dataclass(frozen=True, kw_only=True) class ReadStorageTask(CacheTask): start_read_block_idx: int timeout: float = 30.0 + @dataclass(frozen=True, kw_only=True) class 
WriteStorageTask(CacheTask): - timeout: float = 30.0 \ No newline at end of file + timeout: float = 30.0 diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 7135484f113..00e95aabc9e 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -22,15 +22,15 @@ import threading import time import traceback -from typing import List from math import prod +from typing import List import numpy as np import paddle from fastdeploy import envs -from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.cache_data import CacheStatus +from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.ops import ( cuda_host_alloc, cuda_host_free, @@ -42,7 +42,7 @@ swap_cache_layout, unset_data_ipc, ) -from fastdeploy.cache_manager.transfer_factory import MooncakeStore, AttentionStore +from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeStore from fastdeploy.config import SpeculativeConfig from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.platforms import current_platform @@ -142,19 +142,19 @@ def __init__(self, args): self.gpu_cache_scales_k_tensors = [] self.gpu_cache_scales_v_tensors = [] self.speculative_config = SpeculativeConfig(args.speculative_config) - + # parse kv cache shape self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")] self.value_cache_shape = [] if args.value_cache_shape: self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")] - + # extract kv cache shape into fields self.num_gpu_blocks = self.key_cache_shape[0] self.head_num = self.key_cache_shape[1] self.block_size = self.key_cache_shape[2] self.head_dim = self.key_cache_shape[3] - + # extract other arg values self.n_ranks = args.mp_num self.rank = args.rank @@ -240,7 +240,7 @@ def __init__(self, args): shard_id=self.rank, shard_num=self.n_ranks, layer_num=self.num_layers + self.num_extra_layers, - block_token_size=self.key_cache_shape[2], # key_cache_shape: [num_blocks, head_num, block_size, head_dim] + block_token_size=self.block_size, bytes_per_shard_layer_per_block=prod(self.key_cache_shape[1:]), device_id=self.device, dp_id=self.local_data_parallel_id, @@ -463,13 +463,13 @@ def _get_cache_bytes(self, cache_dtype): return cache_bytes def _run_read_storage( - self, - token_ids: List[int], - start_read_block_idx: int, - k_cache_keys: List[str], - v_cache_keys: List[str], - gpu_block_ids: List[int], - cpu_block_ids: List[int], + self, + token_ids: List[int], + start_read_block_idx: int, + k_cache_keys: List[str], + v_cache_keys: List[str], + gpu_block_ids: List[int], + cpu_block_ids: List[int], timeout: float, ): """ @@ -481,14 +481,20 @@ def _run_read_storage( v_cache_keys = v_cache_keys[:start_read_block_idx] gpu_block_ids = gpu_block_ids[:start_read_block_idx] cpu_block_ids = cpu_block_ids[:start_read_block_idx] - + block_num = len(gpu_block_ids) keys = k_cache_keys + v_cache_keys - k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids] - v_cache_ptrs = [self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids] + k_cache_ptrs = [ + self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + ] + v_cache_ptrs = [ + self.storage_value_read_buffer + i * 
self.storage_buffer_stride_bytes for i in cpu_block_ids + ] kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value - result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) # TODO(liyonghua): impl for attention store + result = self.storage_backend.batch_get( + keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes + ) # TODO(liyonghua): impl for attention store k_result, v_result = result[:block_num], result[block_num:] success_block_num = 0 @@ -525,15 +531,16 @@ def _run_read_storage( for i in range(self.num_layers + self.num_extra_layers): key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) - read_block_num = self.storage_backend.read(key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout) + read_block_num = self.storage_backend.read( + key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout + ) valid_gpu_block_ids = gpu_block_ids[:read_block_num] return valid_gpu_block_ids except Exception as e: logger.error( - f"An error occurred in _run_read_storage, " - f"error: {e}, traceback:\n{traceback.format_exc()}" + f"An error occurred in _run_read_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) raise @@ -550,24 +557,24 @@ def read_storage_task(self, task: ReadStorageTask): match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys) elif self.storage_backend_type == "attention_store": match_block_num = self.storage_backend.query(task.token_ids, task.start_read_block_idx, task.timeout) - logger.info(f"Before reading cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}") + logger.info( + f"Before reading cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}" + ) valid_gpu_block_ids = [] if match_block_num > 0: # TODO: support timeout with actual block count try: valid_gpu_block_ids = self._run_read_storage( - task.token_ids, - match_block_num, - k_cache_keys, - v_cache_keys, + task.token_ids, + match_block_num, + k_cache_keys, + v_cache_keys, gpu_block_ids, cpu_block_ids, - task.timeout - ) - logger.info( - f"Successfully read {match_block_num} blocks from storage for task {task.task_id}." 
+ task.timeout, ) + logger.info(f"Successfully read {match_block_num} blocks from storage for task {task.task_id}.") except Exception as e: logger.error(f"Failed to read cache for task {task.task_id}, error: {e}") valid_gpu_block_ids = [] @@ -577,14 +584,16 @@ def read_storage_task(self, task: ReadStorageTask): self.cache_task_queue.swap_storage_to_gpu_barrier.reset() self.cache_task_queue.put_transfer_done_signal(result) logger.debug(f"read_storage_task: put_transfer_done_signal {result}") - + except Exception as e: logger.error( f"An error occurred in read_storage_task: " f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}" ) - def _run_write_back_storage(self, token_ids, start_write_block_idx, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids, timeout): + def _run_write_back_storage( + self, token_ids, start_write_block_idx, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids, timeout + ): try: if self.storage_backend_type == "mooncake": k_cache_keys = k_cache_keys[start_write_block_idx:] @@ -636,12 +645,13 @@ def _run_write_back_storage(self, token_ids, start_write_block_idx, k_cache_keys for i in range(self.num_layers + self.num_extra_layers): key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) - self.storage_backend.write(key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout) + self.storage_backend.write( + key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout + ) except Exception as e: logger.error( - f"An error occurred in _run_write_back_storage, " - f"error: {e}, traceback:\n{traceback.format_exc()}" + f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) def write_back_storage_task(self, task: WriteStorageTask): @@ -653,14 +663,16 @@ def write_back_storage_task(self, task: WriteStorageTask): cpu_block_ids = [i for i in range(len(gpu_block_ids))] k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys] v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys] - + match_block_num = 0 if self.storage_backend_type == "mooncake": match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout) elif self.storage_backend_type == "attention_store": match_block_num = self.storage_backend.query(task.token_ids, 0, task.timeout) - logger.info(f"Before writing cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}") - + logger.info( + f"Before writing cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}" + ) + if match_block_num >= len(k_cache_keys): logger.info(f"No uncached keys found for task {task.task_id}") gpu_block_ids = [] @@ -668,13 +680,13 @@ def write_back_storage_task(self, task: WriteStorageTask): try: # TODO: support timeout with actual block count self._run_write_back_storage( - task.token_ids, - match_block_num, - k_cache_keys, - v_cache_keys, + task.token_ids, + match_block_num, + k_cache_keys, + v_cache_keys, gpu_block_ids, cpu_block_ids, - task.timeout + task.timeout, ) logger.info(f"Successfully wrote cache to storage for task {task.task_id}") except Exception as e: @@ -689,8 +701,7 @@ def write_back_storage_task(self, task: WriteStorageTask): logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") except Exception as e: logger.error( - f"An error occurred in 
write_back_storage_task, " - f"error: {e}, traceback:\n{traceback.format_exc()}" + f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) def _do_swap_to_cpu_task( @@ -1083,7 +1094,7 @@ def main(): if args.mp_num > 1: logger = get_logger("cache_transfer", f"cache_transfer.log.{rank_id}") else: - logger = get_logger("cache_transfer", f"cache_transfer.log") + logger = get_logger("cache_transfer", "cache_transfer.log") logger.info(f"args: {vars(args)}") set_device(args.device_id) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 47fa625a0e4..c85aab30e7c 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -807,11 +807,11 @@ def request_match_storage_blocks(self, request, extra_gpu_block_ids): logger.info(f"start prefetch cache from storage, req_id: {req_id}, block num: {len(block_keys)}") task = ReadStorageTask( - task_id=req_id, - keys=block_keys, - token_ids=input_ids, + task_id=req_id, + keys=block_keys, + token_ids=input_ids, gpu_block_ids=extra_gpu_block_ids, - start_read_block_idx=num_cached_tokens // block_size, + start_read_block_idx=num_cached_tokens // block_size, ) matched_block_ids = self.issue_prefetch_storage_task(task, is_sync=True) logger.info( @@ -1013,7 +1013,9 @@ def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): return if len(task.keys) != len(task.gpu_block_ids): - err_msg = f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})" + err_msg = ( + f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})" + ) logger.error(err_msg) raise ValueError(err_msg) diff --git a/fastdeploy/cache_manager/transfer_factory/__init__.py b/fastdeploy/cache_manager/transfer_factory/__init__.py index 5bcf74671d5..34561f15ba4 100644 --- a/fastdeploy/cache_manager/transfer_factory/__init__.py +++ b/fastdeploy/cache_manager/transfer_factory/__init__.py @@ -17,7 +17,7 @@ from fastdeploy.platforms import current_platform from .kvcache_storage import KVCacheStorage -from .mooncake_store import MooncakeStore, AttentionStore +from .mooncake_store import AttentionStore, MooncakeStore from .rdma_cache_transfer import RDMACommManager if current_platform.is_cuda(): diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py index 643bc686a86..00cbd4acf7b 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. """ -from .mooncake_store import MooncakeStore from .attention_store import AttentionStore +from .mooncake_store import MooncakeStore __all__ = ["MooncakeStore", "AttentionStore"] diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index d7bb3bbd48b..a833f98a869 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -14,24 +14,30 @@ # limitations under the License. 
""" -import ctypes -import paddle -import json -import os import time import traceback -import uuid -from dataclasses import dataclass, fields -from typing import Any, List, Optional +from dataclasses import dataclass +from typing import List + +import paddle -import attentionstore_sdk.api.common.common_pb2 as common_pb2 -from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens -from attentionstore_sdk.utils.err import AttentionStoreSDKError from fastdeploy.cache_manager.transfer_factory.kvcache_storage import ( KVCacheStorage, logger, ) +try: + from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens + from attentionstore_sdk.utils.err import AttentionStoreSDKError + + _ATTENTIONSTORE_AVAILABLE = True +except Exception: + AttentionStoreSDK = None + Tokens = None + AttentionStoreSDKError = None + _ATTENTIONSTORE_AVAILABLE = False + + @dataclass class AttentionStoreConfig: namespace: str = "default_ns" @@ -45,16 +51,12 @@ class AttentionStoreConfig: device_id: int = 0 dp_id: int = 0 + class AttentionStore(KVCacheStorage): def __init__(self, **args): - try: - import attentionstore_sdk.api.common.common_pb2 as common_pb2 - from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens - except ImportError as e: - raise ImportError( - "Please install attentionstore_sdk to run Fastdeploy with attentionstore_sdk." - ) from e + if not _ATTENTIONSTORE_AVAILABLE: + raise ImportError("Please install attentionstore_sdk to run Fastdeploy with attentionstore_sdk.") self.config = AttentionStoreConfig(**args) @@ -73,10 +75,12 @@ def __init__(self, **args): self.config.dp_id, ) self.wait_for_sdk_ready(timeout=300, delta_t=5) - logger.info(f"✅ AttentionStoreSDK is inititialized successfully!") + logger.info("✅ AttentionStoreSDK is inititialized successfully!") except Exception as e: - logger.error(f"❌ AttentionStoreSDK initialization failed, error: {e}, traceback: {traceback.format_exc()}" - f"\nconfig: {self.config}") + logger.error( + f"❌ AttentionStoreSDK initialization failed, error: {e}, traceback: {traceback.format_exc()}" + f"\nconfig: {self.config}" + ) def wait_for_sdk_ready(self, timeout: float, delta_t: float): t = 0 @@ -87,28 +91,32 @@ def wait_for_sdk_ready(self, timeout: float, delta_t: float): return except AttentionStoreSDKError as e: if "cuda memory not ready" in str(e): - logger.debug(f"wait_for_sdk_ready: cuda memory not ready, try again..") + logger.debug("wait_for_sdk_ready: cuda memory not ready, try again..") time.sleep(delta_t) continue else: - raise RuntimeError(f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}") + raise RuntimeError( + f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}" + ) finally: t += delta_t raise TimeoutError(f"AttentionStoreSDK initialization timed out after {timeout} seconds") def read( self, - key_cache: List[paddle.Tensor], - val_cache: List[paddle.Tensor], + key_cache: List[paddle.Tensor], + val_cache: List[paddle.Tensor], token_ids: List[int], gpu_block_ids: List[int], start_read_block_idx: int, timeout: float = 30.0, ): - logger.debug(f"read: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_read_block_idx={start_read_block_idx} timeout={timeout}") + logger.debug( + f"read: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_read_block_idx={start_read_block_idx} timeout={timeout}" + ) tokens = Tokens(token_ids, self.config.block_token_size) - k_data_ptrs = [k.data_ptr()for k in key_cache] - v_data_ptrs = [v.data_ptr()for v in val_cache] + 
k_data_ptrs = [k.data_ptr() for k in key_cache] + v_data_ptrs = [v.data_ptr() for v in val_cache] num = 0 try: num = self.sdk.read( @@ -126,18 +134,20 @@ def read( return num def write( - self, - key_cache: List[paddle.Tensor], - val_cache: List[paddle.Tensor], - token_ids: List[int], + self, + key_cache: List[paddle.Tensor], + val_cache: List[paddle.Tensor], + token_ids: List[int], gpu_block_ids: List[int], start_write_block_idx: int, timeout: float = 30.0, ) -> int: - logger.debug(f"write: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_write_block_idx={start_write_block_idx} timeout={timeout}") + logger.debug( + f"write: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_write_block_idx={start_write_block_idx} timeout={timeout}" + ) tokens = Tokens(token_ids, self.config.block_token_size) - k_data_ptrs = [k.data_ptr()for k in key_cache] - v_data_ptrs = [v.data_ptr()for v in val_cache] + k_data_ptrs = [k.data_ptr() for k in key_cache] + v_data_ptrs = [v.data_ptr() for v in val_cache] num = 0 try: num = self.sdk.write( @@ -151,7 +161,9 @@ def write( ) logger.debug(f"write: successfully wrote {num} blocks") except AttentionStoreSDKError as e: - logger.error(f"Failed to execute AttentionStoreSDK write, error: {e}, traceback:\n{traceback.format_exc()}") + logger.error( + f"Failed to execute AttentionStoreSDK write, error: {e}, traceback:\n{traceback.format_exc()}" + ) return num def query(self, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0): @@ -166,26 +178,28 @@ def query(self, token_ids: List[int], start_match_block_idx: int, timeout: float num = self.sdk.match(tokens, start_match_block_idx, timeout) logger.debug(f"query: successfully matched {num} blocks") except AttentionStoreSDKError as e: - logger.error(f"Failed to execute AttentionStoreSDK match, error: {e}, traceback:\n{traceback.format_exc()}") + logger.error( + f"Failed to execute AttentionStoreSDK match, error: {e}, traceback:\n{traceback.format_exc()}" + ) return num def get(self, **kwargs): - raise NotImplementedError(f"AttentionStore does not support this method") + raise NotImplementedError("AttentionStore does not support this method") def batch_get(self, **kwargs): - raise NotImplementedError(f"AttentionStore does not support this method") + raise NotImplementedError("AttentionStore does not support this method") def set(self, **kwargs) -> bool: - raise NotImplementedError(f"AttentionStore does not support this method") + raise NotImplementedError("AttentionStore does not support this method") def batch_set(self, **kwargs) -> bool: - raise NotImplementedError(f"AttentionStore does not support this method") + raise NotImplementedError("AttentionStore does not support this method") def exists(self, keys: List[str]) -> bool: - raise NotImplementedError(f"AttentionStore does not support this method") + raise NotImplementedError("AttentionStore does not support this method") def clear(self) -> bool: - raise NotImplementedError(f"AttentionStore does not support this method") + raise NotImplementedError("AttentionStore does not support this method") def register_buffer(self, buffer_ptr, buffer_size, buffer_type="none_type") -> None: - raise NotImplementedError(f"AttentionStore does not support this method") \ No newline at end of file + raise NotImplementedError("AttentionStore does not support this method") From cb2d51de085a8faba32db470e0bec6faad09df5c Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Tue, 30 Dec 2025 10:45:12 +0000 Subject: [PATCH 03/16] [chore] optimize log --- 
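Usage sketch (illustration only, not part of the patch): once task_id is threaded through query/read/write, a caller drives this backend roughly as below. All literals (shapes, ids, the "demo" namespace) are made up for illustration, the token slice mirrors the block-granularity contract (one block = block_token_size tokens), and it only runs where attentionstore_sdk and paddle are installed.

    import paddle

    from fastdeploy.cache_manager.transfer_factory import AttentionStore

    BLOCK_TOKENS = 64  # tokens per KV block, must match block_token_size below

    store = AttentionStore(
        namespace="demo",
        shard_id=0,
        shard_num=1,
        layer_num=2,
        block_token_size=BLOCK_TOKENS,
        bytes_per_shard_layer_per_block=8 * BLOCK_TOKENS * 128 * 2,  # head_num * block_size * head_dim * dtype bytes
        device_id=0,
        dp_id=0,
    )

    # One KV tensor per layer, laid out as [num_blocks, head_num, block_size, head_dim].
    key_cache = [paddle.zeros([16, 8, BLOCK_TOKENS, 128], dtype="float16") for _ in range(2)]
    val_cache = [paddle.zeros([16, 8, BLOCK_TOKENS, 128], dtype="float16") for _ in range(2)]

    token_ids = list(range(2 * BLOCK_TOKENS))  # exactly two full blocks of tokens
    gpu_block_ids = [3, 7]                     # destination block ids in the GPU cache

    # query() reports how many leading blocks are already in storage; read() then
    # fills the corresponding GPU blocks and returns how many it actually copied.
    matched = store.query("req-demo", token_ids, start_match_block_idx=0)
    if matched > 0:
        copied = store.read(
            "req-demo",
            key_cache,
            val_cache,
            token_ids[: matched * BLOCK_TOKENS],
            gpu_block_ids[:matched],
            start_read_block_idx=0,
        )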
.../cache_manager/cache_transfer_manager.py | 47 ++++++++++++------- .../cache_manager/prefix_cache_manager.py | 6 ++- .../mooncake_store/attention_store.py | 41 +++++++++------- 3 files changed, 58 insertions(+), 36 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 00e95aabc9e..575a838ed5b 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -464,6 +464,7 @@ def _get_cache_bytes(self, cache_dtype): def _run_read_storage( self, + task_id: str, token_ids: List[int], start_read_block_idx: int, k_cache_keys: List[str], @@ -532,7 +533,7 @@ def _run_read_storage( key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) read_block_num = self.storage_backend.read( - key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout + task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout ) valid_gpu_block_ids = gpu_block_ids[:read_block_num] @@ -556,16 +557,17 @@ def read_storage_task(self, task: ReadStorageTask): if self.storage_backend_type == "mooncake": match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys) elif self.storage_backend_type == "attention_store": - match_block_num = self.storage_backend.query(task.token_ids, task.start_read_block_idx, task.timeout) - logger.info( - f"Before reading cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}" - ) + match_block_num = self.storage_backend.query( + task.task_id, task.token_ids, task.start_read_block_idx, task.timeout + ) + logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}") valid_gpu_block_ids = [] if match_block_num > 0: # TODO: support timeout with actual block count try: valid_gpu_block_ids = self._run_read_storage( + task.task_id, task.token_ids, match_block_num, k_cache_keys, @@ -574,7 +576,9 @@ def read_storage_task(self, task: ReadStorageTask): cpu_block_ids, task.timeout, ) - logger.info(f"Successfully read {match_block_num} blocks from storage for task {task.task_id}.") + logger.info( + f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}" + ) except Exception as e: logger.error(f"Failed to read cache for task {task.task_id}, error: {e}") valid_gpu_block_ids = [] @@ -583,7 +587,7 @@ def read_storage_task(self, task: ReadStorageTask): self.cache_task_queue.swap_storage_to_gpu_barrier.wait() self.cache_task_queue.swap_storage_to_gpu_barrier.reset() self.cache_task_queue.put_transfer_done_signal(result) - logger.debug(f"read_storage_task: put_transfer_done_signal {result}") + logger.debug(f"read_storage_task: put transfer done signal for {task.task_id}") except Exception as e: logger.error( @@ -592,7 +596,15 @@ def read_storage_task(self, task: ReadStorageTask): ) def _run_write_back_storage( - self, token_ids, start_write_block_idx, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids, timeout + self, + task_id, + token_ids, + start_write_block_idx, + k_cache_keys, + v_cache_keys, + gpu_block_ids, + cpu_block_ids, + timeout, ): try: if self.storage_backend_type == "mooncake": @@ -638,6 +650,7 @@ def _run_write_back_storage( kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value 
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) + return block_num elif self.storage_backend_type == "attention_store": key_cache = [] @@ -645,9 +658,10 @@ def _run_write_back_storage( for i in range(self.num_layers + self.num_extra_layers): key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) - self.storage_backend.write( - key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout + write_block_num = self.storage_backend.write( + task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout ) + return write_block_num except Exception as e: logger.error( @@ -668,10 +682,8 @@ def write_back_storage_task(self, task: WriteStorageTask): if self.storage_backend_type == "mooncake": match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout) elif self.storage_backend_type == "attention_store": - match_block_num = self.storage_backend.query(task.token_ids, 0, task.timeout) - logger.info( - f"Before writing cache from storage, found {match_block_num} blocks already cached in storage for task {task.task_id}" - ) + match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout) + logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}") if match_block_num >= len(k_cache_keys): logger.info(f"No uncached keys found for task {task.task_id}") @@ -679,7 +691,8 @@ def write_back_storage_task(self, task: WriteStorageTask): else: try: # TODO: support timeout with actual block count - self._run_write_back_storage( + write_block_num = self._run_write_back_storage( + task.task_id, task.token_ids, match_block_num, k_cache_keys, @@ -688,7 +701,9 @@ def write_back_storage_task(self, task: WriteStorageTask): cpu_block_ids, task.timeout, ) - logger.info(f"Successfully wrote cache to storage for task {task.task_id}") + logger.info( + f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" + ) except Exception as e: logger.error(f"Error in write back storage task: {e}") gpu_block_ids = [] diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index c85aab30e7c..6a1c276d457 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -33,6 +33,7 @@ from fastdeploy.cache_manager.cache_metrics import CacheMetrics from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.ops import get_all_visible_devices +from fastdeploy.config import CacheConfig from fastdeploy.engine.request import Request from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus from fastdeploy.metrics.metrics import main_process_metrics @@ -48,7 +49,7 @@ class PrefixCacheManager: def __init__( self, - config, + config: CacheConfig, tensor_parallel_size, splitwise_role="mixed", local_data_parallel_id=0, @@ -1000,7 +1001,8 @@ def write_cache_to_storage(self, request: Request): task = WriteStorageTask( task_id=req_id, keys=keys, - token_ids=request.prompt_token_ids + request.output_token_ids, + token_ids=request.prompt_token_ids + + (request.output_token_ids if self.config.enable_output_caching else 0), gpu_block_ids=gpu_block_ids, ) tic = time.time() diff --git 
a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index a833f98a869..6f325afba24 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -61,7 +61,7 @@ def __init__(self, **args): self.config = AttentionStoreConfig(**args) try: - logger.info(f"Start initializing AttentionStoreSDK with config: {self.config}") + logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}") self.sdk = AttentionStoreSDK( self.config.namespace, self.config.pod_name, @@ -75,11 +75,10 @@ def __init__(self, **args): self.config.dp_id, ) self.wait_for_sdk_ready(timeout=300, delta_t=5) - logger.info("✅ AttentionStoreSDK is inititialized successfully!") + logger.info("[INIT] ✅ AttentionStore is initialized successfully!") except Exception as e: logger.error( - f"❌ AttentionStoreSDK initialization failed, error: {e}, traceback: {traceback.format_exc()}" - f"\nconfig: {self.config}" + f"[INIT] ❌ AttentionStore initialization failed, error: {e}, traceback:\n{traceback.format_exc()}" ) def wait_for_sdk_ready(self, timeout: float, delta_t: float): @@ -91,7 +90,7 @@ def wait_for_sdk_ready(self, timeout: float, delta_t: float): return except AttentionStoreSDKError as e: if "cuda memory not ready" in str(e): - logger.debug("wait_for_sdk_ready: cuda memory not ready, try again..") + logger.debug("[INIT] cuda memory not ready, try again..") time.sleep(delta_t) continue else: @@ -104,6 +103,7 @@ def wait_for_sdk_ready(self, timeout: float, delta_t: float): def read( self, + task_id: str, key_cache: List[paddle.Tensor], val_cache: List[paddle.Tensor], token_ids: List[int], @@ -112,7 +112,7 @@ def read( timeout: float = 30.0, ): logger.debug( - f"read: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_read_block_idx={start_read_block_idx} timeout={timeout}" + f"[READ] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_read_block_idx: {start_read_block_idx} timeout: {timeout}" ) tokens = Tokens(token_ids, self.config.block_token_size) k_data_ptrs = [k.data_ptr() for k in key_cache] @@ -128,13 +128,16 @@ def read( gpu_block_ids, timeout, ) - logger.debug(f"read: successfully read {num} blocks") - except AttentionStoreSDKError as e: - logger.error(f"Failed to execute AttentionStoreSDK read, error: {e}, traceback:\n{traceback.format_exc()}") + logger.debug(f"[READ] task_id: {task_id} read_blocks={num}") + except AttentionStoreSDKError: + logger.error( + f"[READ] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}" + ) return num def write( self, + task_id: str, key_cache: List[paddle.Tensor], val_cache: List[paddle.Tensor], token_ids: List[int], @@ -143,7 +146,7 @@ def write( timeout: float = 30.0, ) -> int: logger.debug( - f"write: token_ids={token_ids} gpu_block_ids={gpu_block_ids} start_write_block_idx={start_write_block_idx} timeout={timeout}" + f"[WRITE] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}" ) tokens = Tokens(token_ids, self.config.block_token_size) k_data_ptrs = [k.data_ptr() for k in key_cache] @@ -159,27 +162,29 @@ def write( gpu_block_ids, timeout, ) - logger.debug(f"write: successfully wrote {num} blocks") - except AttentionStoreSDKError as e: + logger.debug(f"[WRITE] task_id: {task_id} written_blocks: {num}") + 
except AttentionStoreSDKError: logger.error( - f"Failed to execute AttentionStoreSDK write, error: {e}, traceback:\n{traceback.format_exc()}" + f"[WRITE] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}" ) return num - def query(self, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0): + def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0): """ Given the input ids and starting index to match, get the valid blocks number that can be prefetched from storage backend. """ - logger.debug(f"query: token_ids={token_ids} start_match_block_idx={start_match_block_idx} timeout={timeout}") + logger.debug( + f"[QUERY] task_id: {task_id} token_ids: {token_ids} start_match_block_idx: {start_match_block_idx} timeout: {timeout}" + ) tokens = Tokens(token_ids, self.config.block_token_size) num = 0 try: num = self.sdk.match(tokens, start_match_block_idx, timeout) - logger.debug(f"query: successfully matched {num} blocks") - except AttentionStoreSDKError as e: + logger.debug(f"[QUERY] task_id: {task_id} matched_blocks: {num}") + except AttentionStoreSDKError: logger.error( - f"Failed to execute AttentionStoreSDK match, error: {e}, traceback:\n{traceback.format_exc()}" + f"[QUERY] Failed to execute sdk match, task_id: {task_id}, traceback:\n{traceback.format_exc()}" ) return num From 92a039128858417a08775f0ae338267a12da94a1 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Wed, 31 Dec 2025 04:29:22 +0000 Subject: [PATCH 04/16] [fix] fix write storage task --- .../cache_manager/cache_transfer_manager.py | 1 + .../cache_manager/prefix_cache_manager.py | 15 +++++++++++---- .../mooncake_store/attention_store.py | 18 +++++++++--------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 575a838ed5b..bf0feda4771 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -658,6 +658,7 @@ def _run_write_back_storage( for i in range(self.num_layers + self.num_extra_layers): key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) + gpu_block_ids = gpu_block_ids[start_write_block_idx:] write_block_num = self.storage_backend.write( task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout ) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 6a1c276d457..88178052e01 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -33,7 +33,7 @@ from fastdeploy.cache_manager.cache_metrics import CacheMetrics from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.ops import get_all_visible_devices -from fastdeploy.config import CacheConfig +from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus from fastdeploy.metrics.metrics import main_process_metrics @@ -49,7 +49,7 @@ class PrefixCacheManager: def __init__( self, - config: CacheConfig, + config: FDConfig, tensor_parallel_size, splitwise_role="mixed", local_data_parallel_id=0, @@ -814,6 +814,7 @@ def 
request_match_storage_blocks(self, request, extra_gpu_block_ids): gpu_block_ids=extra_gpu_block_ids, start_read_block_idx=num_cached_tokens // block_size, ) + logger.debug(f"issue read storage task: {task}") matched_block_ids = self.issue_prefetch_storage_task(task, is_sync=True) logger.info( f"finish prefetch cache from storage, req_id: {req_id}, matched block num: {len(matched_block_ids)}" @@ -986,6 +987,12 @@ def write_cache_to_storage(self, request: Request): if self.kvcache_storage_backend is None: return + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + if self.config.cache_config.enable_output_caching: + token_ids += request.output_token_ids + req_id = request.request_id keys = [] node = self.req_leaf_map[req_id] @@ -1001,10 +1008,10 @@ def write_cache_to_storage(self, request: Request): task = WriteStorageTask( task_id=req_id, keys=keys, - token_ids=request.prompt_token_ids - + (request.output_token_ids if self.config.enable_output_caching else 0), + token_ids=token_ids, gpu_block_ids=gpu_block_ids, ) + logger.debug(f"issue write storage task: {task}") tic = time.time() self.issue_write_back_storage_task(task, is_sync=True) cost_time = time.time() - tic diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index 6f325afba24..4960285245a 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -112,7 +112,7 @@ def read( timeout: float = 30.0, ): logger.debug( - f"[READ] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_read_block_idx: {start_read_block_idx} timeout: {timeout}" + f"[READ BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_read_block_idx: {start_read_block_idx} timeout: {timeout}" ) tokens = Tokens(token_ids, self.config.block_token_size) k_data_ptrs = [k.data_ptr() for k in key_cache] @@ -128,10 +128,10 @@ def read( gpu_block_ids, timeout, ) - logger.debug(f"[READ] task_id: {task_id} read_blocks={num}") + logger.debug(f"[READ END] task_id: {task_id} read_blocks={num}") except AttentionStoreSDKError: logger.error( - f"[READ] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}" + f"[READ ERROR] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}" ) return num @@ -146,7 +146,7 @@ def write( timeout: float = 30.0, ) -> int: logger.debug( - f"[WRITE] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}" + f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}" ) tokens = Tokens(token_ids, self.config.block_token_size) k_data_ptrs = [k.data_ptr() for k in key_cache] @@ -162,10 +162,10 @@ def write( gpu_block_ids, timeout, ) - logger.debug(f"[WRITE] task_id: {task_id} written_blocks: {num}") + logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}") except AttentionStoreSDKError: logger.error( - f"[WRITE] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}" + f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}" ) return num @@ -175,16 +175,16 @@ def query(self, task_id: str, 
token_ids: List[int], start_match_block_idx: int, can be prefetched from storage backend. """ logger.debug( - f"[QUERY] task_id: {task_id} token_ids: {token_ids} start_match_block_idx: {start_match_block_idx} timeout: {timeout}" + f"[QUERY BEGIN] task_id: {task_id} token_ids: {token_ids} start_match_block_idx: {start_match_block_idx} timeout: {timeout}" ) tokens = Tokens(token_ids, self.config.block_token_size) num = 0 try: num = self.sdk.match(tokens, start_match_block_idx, timeout) - logger.debug(f"[QUERY] task_id: {task_id} matched_blocks: {num}") + logger.debug(f"[QUERY END] task_id: {task_id} matched_blocks: {num}") except AttentionStoreSDKError: logger.error( - f"[QUERY] Failed to execute sdk match, task_id: {task_id}, traceback:\n{traceback.format_exc()}" + f"[QUERY ERROR] Failed to execute sdk match, task_id: {task_id}, traceback:\n{traceback.format_exc()}" ) return num From a4c9904fc0266521e5ecff5a6e62f371f27ec22e Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Sun, 4 Jan 2026 09:42:33 +0000 Subject: [PATCH 05/16] [fix] fix read storage --- docs/zh/online_serving/metrics.md | 2 +- .../cache_manager/cache_transfer_manager.py | 22 ++++++++----------- .../engine/sched/resource_manager_v1.py | 2 +- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/docs/zh/online_serving/metrics.md b/docs/zh/online_serving/metrics.md index 75576995ebb..fe559a27337 100644 --- a/docs/zh/online_serving/metrics.md +++ b/docs/zh/online_serving/metrics.md @@ -32,7 +32,7 @@ | KV缓存 | `fastdeploy:gpu_hit_token_rate` | Gauge | token 级别 GPU 前缀缓存命中率 | 百分比 | | KV缓存 | `fastdeploy:prefix_cache_token_num` | Counter | 前缀缓存token总数 | 个 | | KV缓存 | `fastdeploy:prefix_gpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 | -| KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 | +| KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 CPU 上的前缀缓存 token 总数 | 个 | | KV缓存 | `fastdeploy:available_gpu_block_num` | Gauge | 缓存中可用的 GPU 块数量(包含尚未正式释放的前缀缓存块)| 个 | | KV缓存 | `fastdeploy:free_gpu_block_num` | Gauge | 缓存中的可用块数 | 个 | | KV缓存 | `fastdeploy:max_gpu_block_num` | Gauge | 服务启动时确定的总块数 | 个 | diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index bf0feda4771..f029d3be0ac 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -478,11 +478,6 @@ def _run_read_storage( """ try: if self.storage_backend_type == "mooncake": - k_cache_keys = k_cache_keys[:start_read_block_idx] - v_cache_keys = v_cache_keys[:start_read_block_idx] - gpu_block_ids = gpu_block_ids[:start_read_block_idx] - cpu_block_ids = cpu_block_ids[:start_read_block_idx] - block_num = len(gpu_block_ids) keys = k_cache_keys + v_cache_keys k_cache_ptrs = [ @@ -552,7 +547,6 @@ def read_storage_task(self, task: ReadStorageTask): cpu_block_ids = [i for i in range(len(gpu_block_ids))] k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys] v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys] - match_block_num = 0 if self.storage_backend_type == "mooncake": match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys) @@ -562,6 +556,10 @@ def read_storage_task(self, task: ReadStorageTask): ) logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}") + k_cache_keys = k_cache_keys[:match_block_num] + v_cache_keys = v_cache_keys[:match_block_num] + gpu_block_ids = gpu_block_ids[:match_block_num] + cpu_block_ids = 
cpu_block_ids[:match_block_num] valid_gpu_block_ids = [] if match_block_num > 0: # TODO: support timeout with actual block count @@ -569,7 +567,7 @@ def read_storage_task(self, task: ReadStorageTask): valid_gpu_block_ids = self._run_read_storage( task.task_id, task.token_ids, - match_block_num, + task.start_read_block_idx + match_block_num, k_cache_keys, v_cache_keys, gpu_block_ids, @@ -608,11 +606,6 @@ def _run_write_back_storage( ): try: if self.storage_backend_type == "mooncake": - k_cache_keys = k_cache_keys[start_write_block_idx:] - v_cache_keys = v_cache_keys[start_write_block_idx:] - gpu_block_ids = gpu_block_ids[start_write_block_idx:] - cpu_block_ids = cpu_block_ids[start_write_block_idx:] - key_cache_size = [ self.key_cache_shape[0], self.key_cache_shape[1], @@ -658,7 +651,6 @@ def _run_write_back_storage( for i in range(self.num_layers + self.num_extra_layers): key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) - gpu_block_ids = gpu_block_ids[start_write_block_idx:] write_block_num = self.storage_backend.write( task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout ) @@ -686,6 +678,10 @@ def write_back_storage_task(self, task: WriteStorageTask): match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout) logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}") + k_cache_keys = k_cache_keys[match_block_num:] + v_cache_keys = v_cache_keys[match_block_num:] + gpu_block_ids = gpu_block_ids[match_block_num:] + cpu_block_ids = cpu_block_ids[match_block_num:] if match_block_num >= len(k_cache_keys): logger.info(f"No uncached keys found for task {task.task_id}") gpu_block_ids = [] diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 777b559748c..f8550af5f48 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -940,7 +940,7 @@ def get_prefix_cached_blocks(self, request: Request): # Report the number of cached tokens to Prometheus metrics main_process_metrics.prefix_cache_token_num.inc(matched_token_num) main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num) - main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.gpu_cache_token_num) + main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.cpu_cache_token_num) if matched_token_num == request.need_prefill_tokens: request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size From ad841d176a231d9e8b3b918963f0a4076ee0a6d2 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Mon, 5 Jan 2026 08:23:25 +0000 Subject: [PATCH 06/16] [fix] fix code conflict after merge develop --- fastdeploy/cache_manager/prefix_cache_manager.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 1185961d076..9775131280c 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -783,9 +783,15 @@ def request_match_blocks(self, task: Request, block_size, *args): f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}" ) start_time = time.time() - storage_matched_block_ids = 
self.issue_prefetch_storage_task( - req_id, no_match_block_keys, gpu_recv_storage_block_ids + read_storage_task = ReadStorageTask( + task_id=req_id, + keys=no_match_block_keys, + token_ids=input_token_ids, + gpu_block_ids=gpu_recv_storage_block_ids, + start_read_block_idx=match_token_num // block_size, ) + logger.debug(f"issue read storage task: {read_storage_task}") + storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task) storage_matched_block_num = len(storage_matched_block_ids) storage_match_token_num = storage_matched_block_num * block_size cost_time = time.time() - start_time @@ -1015,15 +1021,15 @@ def write_cache_to_storage(self, request: Request): gpu_block_ids = request.block_tables[: len(keys)] logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}") - task = WriteStorageTask( + write_storage_task = WriteStorageTask( task_id=req_id, keys=keys, token_ids=token_ids, gpu_block_ids=gpu_block_ids, ) - logger.debug(f"issue write storage task: {task}") + logger.debug(f"issue write storage task: {write_storage_task}") tic = time.time() - self.issue_write_back_storage_task(task, is_sync=True) + self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") From 4af8c668d8322f2bdd027b084a3743e8c9eed04b Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Mon, 5 Jan 2026 11:31:42 +0000 Subject: [PATCH 07/16] [fix] fix cache bytes and read task token ids --- .../cache_manager/cache_transfer_manager.py | 24 +++++++++---------- .../mooncake_store/attention_store.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 5f6a815cefd..ca33d0929e0 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -22,7 +22,6 @@ import threading import time import traceback -from math import prod from typing import List import numpy as np @@ -155,13 +154,16 @@ def __init__(self, args): self.block_size = self.key_cache_shape[2] self.head_dim = self.key_cache_shape[3] + # compute cache bytes + self.cache_dtype = args.cache_dtype + self.cache_bytes = self._get_cache_bytes(self.cache_dtype) + # extract other arg values self.n_ranks = args.mp_num self.rank = args.rank self.device = args.device_id self.num_layers = args.num_layers self.ipc_suffix = args.ipc_suffix - self.cache_dtype = args.cache_dtype self.local_data_parallel_id = args.local_data_parallel_id self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) @@ -241,7 +243,7 @@ def __init__(self, args): shard_num=self.n_ranks, layer_num=self.num_layers + self.num_extra_layers, block_token_size=self.block_size, - bytes_per_shard_layer_per_block=prod(self.key_cache_shape[1:]), + bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes, device_id=self.device, dp_id=self.local_data_parallel_id, ) @@ -263,17 +265,15 @@ def _init_storage_buffer(self, args): buffer layout: [block_num, layer_num, head_num, block_size, head_dim] """ layer_num = self.num_layers + self.num_extra_layers - head_num = self.key_cache_shape[1] - block_size = self.key_cache_shape[2] - head_dim = self.key_cache_shape[3] - block_num = 
(args.max_model_len + block_size - 1) // block_size + block_num = (args.max_model_len + self.block_size - 1) // self.block_size logger.info( f"Creating cache buffer for storage with shape: " - f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]" + f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]" ) - self.cache_bytes = self._get_cache_bytes(self.cache_dtype) - self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes + self.storage_buffer_stride_bytes = ( + layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes + ) total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB") @@ -566,8 +566,8 @@ def read_storage_task(self, task: ReadStorageTask): try: valid_gpu_block_ids = self._run_read_storage( task.task_id, - task.token_ids, - task.start_read_block_idx + match_block_num, + task.token_ids[: match_block_num * self.block_size], + task.start_read_block_idx, k_cache_keys, v_cache_keys, gpu_block_ids, diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index 4960285245a..5cbd39694fc 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -128,7 +128,7 @@ def read( gpu_block_ids, timeout, ) - logger.debug(f"[READ END] task_id: {task_id} read_blocks={num}") + logger.debug(f"[READ END] task_id: {task_id} read_blocks: {num}") except AttentionStoreSDKError: logger.error( f"[READ ERROR] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}" From fda45979d3a76f7d18dffc7e1d9c0bd9c60b7413 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Tue, 6 Jan 2026 06:55:36 +0000 Subject: [PATCH 08/16] [chore] add model for cache transfer manager --- fastdeploy/cache_manager/cache_transfer_manager.py | 3 +++ fastdeploy/cache_manager/prefix_cache_manager.py | 1 + 2 files changed, 4 insertions(+) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index ca33d0929e0..a97c6d907da 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -59,6 +59,7 @@ def parse_args(): default="mixed", help="splitwise role, can be decode, prefill or mixed", ) + parser.add_argument("--model_id", type=str, default="default", help="model id") parser.add_argument("--rank", type=int, default=0, help="local tp rank") parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--max_model_len", type=int, default=32768, help="max model length") @@ -159,6 +160,7 @@ def __init__(self, args): self.cache_bytes = self._get_cache_bytes(self.cache_dtype) # extract other arg values + self.model_id = args.model_id self.n_ranks = args.mp_num self.rank = args.rank self.device = args.device_id @@ -239,6 +241,7 @@ def __init__(self, args): elif args.kvcache_storage_backend == "attention_store": logger.info("Start initialize attention store...") self.storage_backend = AttentionStore( + namespace=self.model_id, shard_id=self.rank, shard_num=self.n_ranks, layer_num=self.num_layers + self.num_extra_layers, diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py 
index 9775131280c..52b8e95065b 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -273,6 +273,7 @@ def launch_cache_manager( + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}" + f" {sys.executable} {py_path}" + + f" --model_id {os.path.basename(self.config.model_config.model)}" + f" --device_id {int(device_ids[i])}" + f" --rank {i}" + f" --splitwise_role {self.splitwise_role}" From 34bdd428de5f2b33199933972a3d8c5677c700e3 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Wed, 7 Jan 2026 13:58:47 +0000 Subject: [PATCH 09/16] [chore] add some log --- fastdeploy/cache_manager/cache_messager.py | 2 +- fastdeploy/cache_manager/cache_tasks.py | 16 +++++++++++ .../cache_manager/cache_transfer_manager.py | 28 ++++++++++++++++++- 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index ab18ed30f4f..91d79a7a3a4 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -1028,7 +1028,7 @@ def main(): args = parse_args() rank_id = args.rank + args.local_data_parallel_id * args.mp_num if args.mp_num > 1: - logger = get_logger("cache_messager", f"cache_messager.log.{rank_id}") + logger = get_logger("cache_messager", f"cache_messager_{rank_id}.log") else: logger = get_logger("cache_messager", "cache_messager.log") diff --git a/fastdeploy/cache_manager/cache_tasks.py b/fastdeploy/cache_manager/cache_tasks.py index 43ffb6a682f..fe15263827a 100644 --- a/fastdeploy/cache_manager/cache_tasks.py +++ b/fastdeploy/cache_manager/cache_tasks.py @@ -1,3 +1,19 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + from dataclasses import dataclass from typing import List diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index a97c6d907da..14d8b9963b5 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -491,9 +491,11 @@ def _run_read_storage( ] kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + start_time = time.time() result = self.storage_backend.batch_get( keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes ) + read_cost_time = time.time() - start_time k_result, v_result = result[:block_num], result[block_num:] success_block_num = 0 @@ -505,6 +507,7 @@ def _run_read_storage( valid_cpu_block_ids = cpu_block_ids[:success_block_num] mode = 1 # cpu ==> gpu + start_time = time.time() swap_cache_layout( self.gpu_cache_k_tensors, self.storage_key_read_buffer, @@ -523,6 +526,10 @@ def _run_read_storage( self.device, mode, ) + swap_cost_time = time.time() - start_time + logger.debug( + f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s" + ) elif self.storage_backend_type == "attention_store": key_cache = [] @@ -530,10 +537,16 @@ def _run_read_storage( for i in range(self.num_layers + self.num_extra_layers): key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) + + start_time = time.time() read_block_num = self.storage_backend.read( task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout ) + read_cost_time = time.time() - start_time valid_gpu_block_ids = gpu_block_ids[:read_block_num] + logger.debug( + f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s" + ) return valid_gpu_block_ids @@ -616,6 +629,7 @@ def _run_write_back_storage( self.key_cache_shape[3], ] mode = 0 # gpu ==> cpu + start_time = time.time() swap_cache_layout( self.gpu_cache_k_tensors, self.storage_key_write_buffer, @@ -634,6 +648,7 @@ def _run_write_back_storage( self.device, mode, ) + swap_cost_time = time.time() - start_time block_num = len(gpu_block_ids) keys = k_cache_keys + v_cache_keys @@ -645,7 +660,14 @@ def _run_write_back_storage( ] kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + + start_time = time.time() self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) + write_cost_time = time.time() - start_time + + logger.debug( + f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s" + ) return block_num elif self.storage_backend_type == "attention_store": @@ -654,9 +676,13 @@ def _run_write_back_storage( for i in range(self.num_layers + self.num_extra_layers): key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) + + start_time = time.time() write_block_num = self.storage_backend.write( task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout ) + write_cost_time = time.time() - start_time + logger.debug(f"_run_write_back_storage, write_cost_time: {write_cost_time:.6f}s") return write_block_num except Exception as e: 
@@ -1107,7 +1133,7 @@ def main(): args = parse_args() rank_id = args.rank + args.local_data_parallel_id * args.mp_num if args.mp_num > 1: - logger = get_logger("cache_transfer", f"cache_transfer.log.{rank_id}") + logger = get_logger("cache_transfer", f"cache_transfer_{rank_id}.log") else: logger = get_logger("cache_transfer", "cache_transfer.log") From 030ce3214145b2559960b687da90137e41b6165d Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Thu, 8 Jan 2026 04:20:07 +0000 Subject: [PATCH 10/16] [chore] remove launched_cache_manager_signal --- fastdeploy/cache_manager/prefix_cache_manager.py | 2 +- fastdeploy/engine/common_engine.py | 4 ---- fastdeploy/engine/engine.py | 4 ---- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 3fd057b9953..945dbf41cb8 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -392,7 +392,7 @@ def launch_cache_messager( + f" --ipc_suffix {ipc_suffix}" + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}" + f" --speculative_config '{self.speculative_config.to_json_string()}'" - + f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1" + + f" >{log_dir}/launch_cache_messager_{i}.log 2>&1" ) logger.info(f"Launch cache messager, command:{launch_cmd}") cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index cf0908fe6d2..2261b22a8df 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -225,10 +225,6 @@ def check_worker_initialize_status_func(res: dict): device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) - # Set cache manager signal - if self.cfg.scheduler_config.splitwise_role != "mixed": - self.launched_cache_manager_signal.value[0] = 1 - # Worker launched self.check_worker_initialize_status_func_thread.join() if not result_container["worker_is_alive"]: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index dc1f1a451fb..942224e32be 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -181,10 +181,6 @@ def check_worker_initialize_status_func(res: dict): device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix) - # Launch components: scheduler, cache_manager, expert_service et.al. 
- if self.cfg.scheduler_config.splitwise_role != "mixed": - self.launched_cache_manager_signal.value[0] = 1 - if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER: envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0] envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0] From 48617d1aaf8eb67d40743abf2488e9b0d178d90d Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Fri, 9 Jan 2026 05:54:28 +0000 Subject: [PATCH 11/16] [fix] fix write_back_storage_task match_block_num condition --- fastdeploy/cache_manager/cache_transfer_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 14d8b9963b5..6d3227ade97 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -707,15 +707,15 @@ def write_back_storage_task(self, task: WriteStorageTask): match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout) logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}") - k_cache_keys = k_cache_keys[match_block_num:] - v_cache_keys = v_cache_keys[match_block_num:] - gpu_block_ids = gpu_block_ids[match_block_num:] - cpu_block_ids = cpu_block_ids[match_block_num:] if match_block_num >= len(k_cache_keys): logger.info(f"No uncached keys found for task {task.task_id}") gpu_block_ids = [] else: try: + k_cache_keys = k_cache_keys[match_block_num:] + v_cache_keys = v_cache_keys[match_block_num:] + gpu_block_ids = gpu_block_ids[match_block_num:] + cpu_block_ids = cpu_block_ids[match_block_num:] # TODO: support timeout with actual block count write_block_num = self._run_write_back_storage( task.task_id, From 3d6bd87eb5640d15c428a826b146665f0261dbb6 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Mon, 19 Jan 2026 06:33:59 +0000 Subject: [PATCH 12/16] [fix] fix swap_cost_time --- fastdeploy/cache_manager/cache_transfer_manager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 66888a111e8..7fc31ea637a 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -552,9 +552,7 @@ def _run_read_storage( ) read_cost_time = time.time() - start_time valid_gpu_block_ids = gpu_block_ids[:read_block_num] - logger.debug( - f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s" - ) + logger.debug(f"_run_read_storage, read_cost_time: {read_cost_time:.6f}s") return valid_gpu_block_ids From 9625602b2ad0d7ce78f01491e23933970464c1c6 Mon Sep 17 00:00:00 2001 From: liyonghua0910 Date: Thu, 22 Jan 2026 03:11:41 +0000 Subject: [PATCH 13/16] [ci] fix ci --- tests/cache_manager/test_cache_transfer_manager.py | 1 + tests/cache_manager/test_prefix_cache_manager.py | 1 + tests/engine/test_common_engine.py | 2 -- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 92fcfe617a3..b27fac3bd63 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -29,6 +29,7 @@ class Args: mp_num = 1 device_id = 0 speculative_config = {} + model_id = "test_model" 
ipc_suffix = "test_ipc_suffix" cache_queue_port = 9999 pod_ip = "127.0.0.1" diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index 52d1c01040e..1045f860a0c 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -185,6 +185,7 @@ def _create_manager( swap_space=4, ) model_config = SimpleNamespace( + model="test_model", num_attention_heads=1, num_key_value_heads=1, head_dim=1, diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 8fb79b1c801..100069fefa3 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -332,8 +332,6 @@ def fake_init_signals(): self.assertFalse(ok) # cache manager started before workers (lines 184-185) self.assertTrue(started_cache.get("called", False)) - # launched_cache_manager_signal set (line 221) - self.assertEqual(int(eng.launched_cache_manager_signal.value[0]), 1) # avoid atexit finalizer if hasattr(eng, "_finalizer"): try: From 117d28aeb08052f752e1d9c754a57e5d03677f02 Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:53:20 +0800 Subject: [PATCH 14/16] Update fastdeploy/engine/sched/resource_manager_v1.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- fastdeploy/engine/sched/resource_manager_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index cb88dc42ff3..83396ddf753 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1341,7 +1341,7 @@ def update_metrics(self): blocks_used_by_tasks = set() for task in self.tasks_list: if task is not None: - blocks_used_by_tasks.union(task.block_tables) + blocks_used_by_tasks.update(task.block_tables) main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks)) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) From 9df86a78a5c1505448fd9433290aca432c30a9b9 Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:54:24 +0800 Subject: [PATCH 15/16] Update fastdeploy/cache_manager/cache_transfer_manager.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- fastdeploy/cache_manager/cache_transfer_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 7fc31ea637a..5da4aa17027 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -695,6 +695,7 @@ def _run_write_back_storage( logger.error( f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) + return 0 def write_back_storage_task(self, task: WriteStorageTask): """ From 2b4865901295696971fe642bac2b80ab3f900f97 Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Thu, 22 Jan 2026 18:00:24 +0800 Subject: [PATCH 16/16] Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- 
.../transfer_factory/mooncake_store/attention_store.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index 5cbd39694fc..c67ac22574d 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -92,13 +92,12 @@ def wait_for_sdk_ready(self, timeout: float, delta_t: float): if "cuda memory not ready" in str(e): logger.debug("[INIT] cuda memory not ready, try again..") time.sleep(delta_t) + t += delta_t continue else: raise RuntimeError( f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}" ) - finally: - t += delta_t raise TimeoutError(f"AttentionStoreSDK initialization timed out after {timeout} seconds") def read(