diff --git a/docs/zh/online_serving/metrics.md b/docs/zh/online_serving/metrics.md index 75576995ebb..fe559a27337 100644 --- a/docs/zh/online_serving/metrics.md +++ b/docs/zh/online_serving/metrics.md @@ -32,7 +32,7 @@ | KV缓存 | `fastdeploy:gpu_hit_token_rate` | Gauge | token 级别 GPU 前缀缓存命中率 | 百分比 | | KV缓存 | `fastdeploy:prefix_cache_token_num` | Counter | 前缀缓存token总数 | 个 | | KV缓存 | `fastdeploy:prefix_gpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 | -| KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 | +| KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 CPU 上的前缀缓存 token 总数 | 个 | | KV缓存 | `fastdeploy:available_gpu_block_num` | Gauge | 缓存中可用的 GPU 块数量(包含尚未正式释放的前缀缓存块)| 个 | | KV缓存 | `fastdeploy:free_gpu_block_num` | Gauge | 缓存中的可用块数 | 个 | | KV缓存 | `fastdeploy:max_gpu_block_num` | Gauge | 服务启动时确定的总块数 | 个 | diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 4ba583cbfa9..65f702b794a 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -1051,7 +1051,10 @@ def main(): args = parse_args() rank_id = args.rank + args.local_data_parallel_id * args.mp_num - logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log") + if args.mp_num > 1: + logger = get_logger("cache_messager", f"cache_messager_{rank_id}.log") + else: + logger = get_logger("cache_messager", "cache_messager.log") logger.info("create cache messager...") logger.info(f"{args}") diff --git a/fastdeploy/cache_manager/cache_tasks.py b/fastdeploy/cache_manager/cache_tasks.py new file mode 100644 index 00000000000..fe15263827a --- /dev/null +++ b/fastdeploy/cache_manager/cache_tasks.py @@ -0,0 +1,37 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
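+#
+# Task descriptions exchanged between PrefixCacheManager and CacheTransferManager:
+# CacheTask carries the request id, block hash keys, token ids and GPU block ids;
+# ReadStorageTask and WriteStorageTask add the fields needed to prefetch from, or
+# write back to, the KV-cache storage backend.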
+""" + +from dataclasses import dataclass +from typing import List + + +@dataclass(frozen=True, kw_only=True) +class CacheTask: + task_id: str + keys: List[str] + token_ids: List[int] + gpu_block_ids: List[int] + + +@dataclass(frozen=True, kw_only=True) +class ReadStorageTask(CacheTask): + start_read_block_idx: int + timeout: float = 30.0 + + +@dataclass(frozen=True, kw_only=True) +class WriteStorageTask(CacheTask): + timeout: float = 30.0 diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 56360d3bf38..5da4aa17027 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -29,6 +29,7 @@ from fastdeploy import envs from fastdeploy.cache_manager.cache_data import CacheStatus +from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.ops import ( cuda_host_alloc, cuda_host_free, @@ -40,7 +41,7 @@ swap_cache_layout, unset_data_ipc, ) -from fastdeploy.cache_manager.transfer_factory import MooncakeStore +from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeStore from fastdeploy.config import SpeculativeConfig from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.platforms import current_platform @@ -58,6 +59,7 @@ def parse_args(): default="mixed", help="splitwise role, can be decode, prefill or mixed", ) + parser.add_argument("--model_id", type=str, default="default", help="model id") parser.add_argument("--rank", type=int, default=0, help="local tp rank") parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--max_model_len", type=int, default=32768, help="max model length") @@ -109,7 +111,7 @@ def parse_args(): "--kvcache_storage_backend", type=str, default=None, - choices=["mooncake", "none"], + choices=["mooncake", "attention_store", "none"], help="The storage backend for kvcache storage. 
If not set, storage backend is disabled.", ) parser.add_argument( @@ -133,8 +135,6 @@ def __init__(self, args): """ 初始化CacheTransferManager """ - device = args.device_id - rank = args.rank self.gpu_cache_kvs = {} self.cpu_cache_kvs = {} self.gpu_cache_k_tensors = [] @@ -142,11 +142,31 @@ def __init__(self, args): self.gpu_cache_scales_k_tensors = [] self.gpu_cache_scales_v_tensors = [] self.speculative_config = SpeculativeConfig(args.speculative_config) + + # parse kv cache shape self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")] self.value_cache_shape = [] if args.value_cache_shape: self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")] + + # extract kv cache shape into fields self.num_gpu_blocks = self.key_cache_shape[0] + self.head_num = self.key_cache_shape[1] + self.block_size = self.key_cache_shape[2] + self.head_dim = self.key_cache_shape[3] + + # compute cache bytes + self.cache_dtype = args.cache_dtype + self.cache_bytes = self._get_cache_bytes(self.cache_dtype) + + # extract other arg values + self.model_id = args.model_id + self.n_ranks = args.mp_num + self.rank = args.rank + self.device = args.device_id + self.num_layers = args.num_layers + self.ipc_suffix = args.ipc_suffix + self.local_data_parallel_id = args.local_data_parallel_id self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) paddle.set_default_dtype(args.default_dtype) @@ -158,18 +178,13 @@ def __init__(self, args): self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) self.transfer_task_queue = queue.Queue() # 用来接收传输任务 self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕 - self.n_ranks = args.mp_num - self.rank = rank - self.device = device - self.ipc_suffix = args.ipc_suffix - self.cache_dtype = args.cache_dtype address = (args.pod_ip, args.cache_queue_port) self.cache_task_queue = EngineCacheQueue( address=address, is_server=False, num_client=args.mp_num, - client_id=rank, + client_id=self.rank, local_data_parallel_id=args.local_data_parallel_id, ) @@ -223,8 +238,22 @@ def __init__(self, args): self.storage_backend = MooncakeStore(tp_rank=self.rank) self._init_storage_buffer(args) logger.info("Initialized mooncake store successfully") + elif args.kvcache_storage_backend == "attention_store": + logger.info("Start initialize attention store...") + self.storage_backend = AttentionStore( + namespace=self.model_id, + shard_id=self.rank, + shard_num=self.n_ranks, + layer_num=self.num_layers + self.num_extra_layers, + block_token_size=self.block_size, + bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes, + device_id=self.device, + dp_id=self.local_data_parallel_id, + ) + logger.info("Initialized attention store successfully!") else: raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}") + self.storage_backend_type = args.kvcache_storage_backend if args.write_policy not in ["write_through"]: raise ValueError(f"Invalid write policy: {args.write_policy}") @@ -246,18 +275,16 @@ def _init_storage_buffer(self, args): cache layout: layer_num * [block_num, head_num, block_size, head_dim] buffer layout: [block_num, layer_num, head_num, block_size, head_dim] """ - layer_num = args.num_layers + self.num_extra_layers - head_num = self.key_cache_shape[1] - block_size = self.key_cache_shape[2] - head_dim = self.key_cache_shape[3] - block_num = 
(args.max_model_len + block_size - 1) // block_size + layer_num = self.num_layers + self.num_extra_layers + block_num = (args.max_model_len + self.block_size - 1) // self.block_size logger.info( f"Creating cache buffer for storage with shape: " - f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]" + f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]" ) - self.cache_bytes = self._get_cache_bytes(self.cache_dtype) - self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes + self.storage_buffer_stride_bytes = ( + layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes + ) total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB") @@ -296,8 +323,8 @@ def _init_gpu_cache(self, args): logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.") set_device(self.device) - for i in range(args.num_layers + self.num_extra_layers): - num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks + for i in range(self.num_layers + self.num_extra_layers): + num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}" val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}" key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}" @@ -415,7 +442,7 @@ def _init_cpu_cache(self, args): self.v_dst_ptrs = [] self.k_scales_ptrs = [] self.v_scales_ptrs = [] - for i in range(args.num_layers + self.num_extra_layers): + for i in range(self.num_layers + self.num_extra_layers): key_name = f"key_caches_{i}_rank{self.rank}" val_name = f"value_caches_{i}_rank{self.rank}" key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}" @@ -446,228 +473,283 @@ def _get_cache_bytes(self, cache_dtype): raise ValueError(f"Unsupported cache dtype: {cache_dtype}") return cache_bytes - def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]): + def _run_read_storage( + self, + task_id: str, + token_ids: List[int], + start_read_block_idx: int, + k_cache_keys: List[str], + v_cache_keys: List[str], + gpu_block_ids: List[int], + cpu_block_ids: List[int], + timeout: float, + ): """ - Given the k_keys and v_keys, get the valid blocks number that - can be prefetched from storage backend. + Read storage data from the given blocks to the corresponding cache tensors on the current rank's GPU. """ - assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length." 
- result = self.storage_backend.exists(k_keys + v_keys) - - # only consider the case when both key and value exist - num = 0 - for k, v in zip(k_keys, v_keys): - if result[k] and result[v]: - num += 1 - return num - - def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids): try: - logger.debug( - f"_run_read_storage, key_hash_keys_num: {len(k_cache_keys)}, " - f"value_hash_keys_num: {len(v_cache_keys)}, gpu_block_ids_num: {len(gpu_block_ids)}, " - f"cpu_block_ids_num: {len(cpu_block_ids)}" - ) + if self.storage_backend_type == "mooncake": + block_num = len(gpu_block_ids) + keys = k_cache_keys + v_cache_keys + k_cache_ptrs = [ + self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + ] + v_cache_ptrs = [ + self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + ] + kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs + kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + start_time = time.time() + result = self.storage_backend.batch_get( + keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes + ) + read_cost_time = time.time() - start_time + + k_result, v_result = result[:block_num], result[block_num:] + success_block_num = 0 + for k, v in zip(k_result, v_result): + if k > 0 and v > 0: + success_block_num += 1 + logger.debug(f"_run_read_storage, success_block_num: {success_block_num}") + valid_gpu_block_ids = gpu_block_ids[:success_block_num] + valid_cpu_block_ids = cpu_block_ids[:success_block_num] + + mode = 1 # cpu ==> gpu + start_time = time.time() + swap_cache_layout( + self.gpu_cache_k_tensors, + self.storage_key_read_buffer, + self.key_cache_shape, + valid_gpu_block_ids, + valid_cpu_block_ids, + self.device, + mode, + ) + swap_cache_layout( + self.gpu_cache_v_tensors, + self.storage_value_read_buffer, + self.value_cache_shape, + valid_gpu_block_ids, + valid_cpu_block_ids, + self.device, + mode, + ) + swap_cost_time = time.time() - start_time + logger.debug( + f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s" + ) - block_num = len(gpu_block_ids) - keys = k_cache_keys + v_cache_keys - k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids] - v_cache_ptrs = [ - self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids - ] - kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs - kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value - start_time = time.time() - result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) - read_cost_time = time.time() - start_time - - k_result, v_result = result[:block_num], result[block_num:] - success_block_num = 0 - for k, v in zip(k_result, v_result): - if k > 0 and v > 0: - success_block_num += 1 - logger.debug(f"_run_read_storage, success_block_num: {success_block_num}") - valid_gpu_block_ids = gpu_block_ids[:success_block_num] - valid_cpu_block_ids = cpu_block_ids[:success_block_num] - - mode = 1 # cpu ==> gpu - start_time = time.time() - swap_cache_layout( - self.gpu_cache_k_tensors, - self.storage_key_read_buffer, - self.key_cache_shape, - valid_gpu_block_ids, - valid_cpu_block_ids, - self.device, - mode, - ) - swap_cache_layout( - self.gpu_cache_v_tensors, - self.storage_value_read_buffer, - self.value_cache_shape, - valid_gpu_block_ids, - valid_cpu_block_ids, - self.device, - mode, - ) - swap_cost_time = 
time.time() - start_time - logger.debug( - f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s" - ) + elif self.storage_backend_type == "attention_store": + key_cache = [] + val_cache = [] + for i in range(self.num_layers + self.num_extra_layers): + key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) + val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) + + start_time = time.time() + read_block_num = self.storage_backend.read( + task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout + ) + read_cost_time = time.time() - start_time + valid_gpu_block_ids = gpu_block_ids[:read_block_num] + logger.debug(f"_run_read_storage, read_cost_time: {read_cost_time:.6f}s") return valid_gpu_block_ids + except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: " - f"error:{e}, {traceback.format_exc()}" + f"An error occurred in _run_read_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) raise - def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1): + def read_storage_task(self, task: ReadStorageTask): """Read cache from the storage backend to the GPU memory.""" try: - logger.debug( - f"read_storage_task, task id: {task_id}, hash_keys_num: {len(keys)}, " - f"gpu_block_ids_num: {len(gpu_block_ids)}, timeout: {timeout}" - ) - k_cache_keys = [f"{key}_key_{self.rank}" for key in keys] - v_cache_keys = [f"{key}_value_{self.rank}" for key in keys] - match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys) - logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}") + gpu_block_ids = task.gpu_block_ids.copy() + cpu_block_ids = [i for i in range(len(gpu_block_ids))] + k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys] + v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys] + match_block_num = 0 + if self.storage_backend_type == "mooncake": + match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys) + elif self.storage_backend_type == "attention_store": + match_block_num = self.storage_backend.query( + task.task_id, task.token_ids, task.start_read_block_idx, task.timeout + ) + logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}") k_cache_keys = k_cache_keys[:match_block_num] v_cache_keys = v_cache_keys[:match_block_num] gpu_block_ids = gpu_block_ids[:match_block_num] - cpu_block_ids = [i for i in range(match_block_num)] + cpu_block_ids = cpu_block_ids[:match_block_num] valid_gpu_block_ids = [] - if match_block_num > 0: # TODO: support timeout with actual block count try: valid_gpu_block_ids = self._run_read_storage( - k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids + task.task_id, + task.token_ids[: match_block_num * self.block_size], + task.start_read_block_idx, + k_cache_keys, + v_cache_keys, + gpu_block_ids, + cpu_block_ids, + task.timeout, ) logger.info( - f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}." 
+ f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}" ) except Exception as e: - logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}") + logger.error(f"Failed to read cache for task {task.task_id}, error: {e}") valid_gpu_block_ids = [] - result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids) + result = (CacheStatus.STORAGE2GPU, task.task_id, task.keys, valid_gpu_block_ids) self.cache_task_queue.swap_storage_to_gpu_barrier.wait() self.cache_task_queue.swap_storage_to_gpu_barrier.reset() self.cache_task_queue.put_transfer_done_signal(result) - logger.debug(f"read_storage_task: put_transfer_done_signal {result}") - logger.info( - f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, " - f"valid block num {len(valid_gpu_block_ids)}" - ) + logger.debug(f"read_storage_task: put transfer done signal for {task.task_id}") + except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: " - f"task_id: {task_id}, error:{e}, {traceback.format_exc()}" + f"An error occurred in read_storage_task: " + f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}" ) - def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids): + def _run_write_back_storage( + self, + task_id, + token_ids, + start_write_block_idx, + k_cache_keys, + v_cache_keys, + gpu_block_ids, + cpu_block_ids, + timeout, + ): try: - logger.debug( - f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, " - f"gpu_block_ids: {gpu_block_ids}" - ) - key_cache_size = [ - self.key_cache_shape[0], - self.key_cache_shape[1], - self.key_cache_shape[2], - self.key_cache_shape[3], - ] + if self.storage_backend_type == "mooncake": + key_cache_size = [ + self.key_cache_shape[0], + self.key_cache_shape[1], + self.key_cache_shape[2], + self.key_cache_shape[3], + ] + mode = 0 # gpu ==> cpu + start_time = time.time() + swap_cache_layout( + self.gpu_cache_k_tensors, + self.storage_key_write_buffer, + key_cache_size, + gpu_block_ids, + cpu_block_ids, + self.device, + mode, + ) + swap_cache_layout( + self.gpu_cache_v_tensors, + self.storage_value_write_buffer, + key_cache_size, + gpu_block_ids, + cpu_block_ids, + self.device, + mode, + ) + swap_cost_time = time.time() - start_time - mode = 0 # gpu ==> cpu - start_time = time.time() - swap_cache_layout( - self.gpu_cache_k_tensors, - self.storage_key_write_buffer, - key_cache_size, - gpu_block_ids, - cpu_block_ids, - self.device, - mode, - ) - swap_cache_layout( - self.gpu_cache_v_tensors, - self.storage_value_write_buffer, - key_cache_size, - gpu_block_ids, - cpu_block_ids, - self.device, - mode, - ) - swap_cost_time = time.time() - start_time + block_num = len(gpu_block_ids) + keys = k_cache_keys + v_cache_keys + k_cache_ptrs = [ + self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + ] + v_cache_ptrs = [ + self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + ] + kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs + kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + + start_time = time.time() + self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) + write_cost_time = time.time() - start_time + + logger.debug( + f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s" + ) + 
return block_num + + elif self.storage_backend_type == "attention_store": + key_cache = [] + val_cache = [] + for i in range(self.num_layers + self.num_extra_layers): + key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"]) + val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"]) + + start_time = time.time() + write_block_num = self.storage_backend.write( + task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout + ) + write_cost_time = time.time() - start_time + logger.debug(f"_run_write_back_storage, write_cost_time: {write_cost_time:.6f}s") + return write_block_num - block_num = len(gpu_block_ids) - keys = k_cache_keys + v_cache_keys - k_cache_ptrs = [ - self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids - ] - v_cache_ptrs = [ - self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids - ] - kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs - kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value - start_time = time.time() - self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) - write_cost_time = time.time() - start_time - - logger.debug( - f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s" - ) except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: " - f"error:{e}, {traceback.format_exc()}" + f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) + return 0 - def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1): + def write_back_storage_task(self, task: WriteStorageTask): """ Write cache to the storage backend from the GPU memory. 
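+ Blocks already present in the storage backend (the matched prefix) are skipped;
+ only the remaining blocks are written before the completion signal is posted back
+ through the cache task queue.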
""" try: - logger.debug( - f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, " - f"task_id: {task_id}, timeout: {timeout}" - ) - - k_cache_keys = [f"{key}_key_{self.rank}" for key in keys] - v_cache_keys = [f"{key}_value_{self.rank}" for key in keys] - match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys) - - k_cache_keys = k_cache_keys[match_block_num:] - v_cache_keys = v_cache_keys[match_block_num:] - gpu_block_ids = gpu_block_ids[match_block_num:] + gpu_block_ids = task.gpu_block_ids.copy() cpu_block_ids = [i for i in range(len(gpu_block_ids))] - - if len(k_cache_keys) == 0: - logger.info(f"No uncached keys found for task {task_id}") + k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys] + v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys] + + match_block_num = 0 + if self.storage_backend_type == "mooncake": + match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout) + elif self.storage_backend_type == "attention_store": + match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout) + logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}") + + if match_block_num >= len(k_cache_keys): + logger.info(f"No uncached keys found for task {task.task_id}") gpu_block_ids = [] else: try: + k_cache_keys = k_cache_keys[match_block_num:] + v_cache_keys = v_cache_keys[match_block_num:] + gpu_block_ids = gpu_block_ids[match_block_num:] + cpu_block_ids = cpu_block_ids[match_block_num:] # TODO: support timeout with actual block count - self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids) + write_block_num = self._run_write_back_storage( + task.task_id, + task.token_ids, + match_block_num, + k_cache_keys, + v_cache_keys, + gpu_block_ids, + cpu_block_ids, + task.timeout, + ) + logger.info( + f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" + ) except Exception as e: logger.error(f"Error in write back storage task: {e}") gpu_block_ids = [] - result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids) + result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids) self.cache_task_queue.swap_to_storage_barrier.wait() if self.rank == 0: # 只有当rank为0时执行同步操作 self.cache_task_queue.swap_to_storage_barrier.reset() self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号 logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") - logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}") except Exception as e: logger.error( - f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: " - f"error:{e}, {traceback.format_exc()}" + f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) def _do_swap_to_cpu_task( @@ -759,12 +841,12 @@ def do_data_transfer(self): self.cache_task_queue.barrier1.reset() if self.cache_task_broadcast_signal.value[0] == 1: data, read_finish = self.cache_task_queue.get_transfer_task() - logger.debug(f"transfer data: get_transfer_task {data}") + logger.debug(f"do_data_transfer: {data}") if read_finish: self.cache_task_broadcast_signal.value[0] = 0 - event_type, transfer_task_id = data[0], data[1] + event_type, event_args = data[0], data[1:] if event_type.value == CacheStatus.SWAP2CPU.value: - swap_node_ids, gpu_block_id, cpu_block_id = data[2:] + transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = 
event_args self.swap_to_cpu_thread_pool.submit( self._do_swap_to_cpu_task, swap_node_ids, @@ -774,7 +856,7 @@ def do_data_transfer(self): transfer_task_id, ) elif event_type.value == CacheStatus.SWAP2GPU.value: - swap_node_ids, gpu_block_id, cpu_block_id = data[2:] + transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args self.swap_to_gpu_thread_pool.submit( self._do_swap_to_gpu_task, swap_node_ids, @@ -784,22 +866,16 @@ def do_data_transfer(self): transfer_task_id, ) elif event_type.value == CacheStatus.STORAGE2GPU.value: - hash_keys, gpu_block_ids, timeout = data[2:] + read_storage_task = event_args[0] self.read_storage_thread_pool.submit( self.read_storage_task, - transfer_task_id, - hash_keys, - gpu_block_ids, - timeout, + read_storage_task, ) elif event_type.value == CacheStatus.GPU2STORAGE.value: - hash_keys, gpu_block_ids, timeout = data[2:] + write_storage_task = event_args[0] self.write_back_storage_thread_pool.submit( self.write_back_storage_task, - transfer_task_id, - hash_keys, - gpu_block_ids, - timeout, + write_storage_task, ) else: if self.n_ranks > 1: @@ -1047,7 +1123,11 @@ def main(): args = parse_args() rank_id = args.rank + args.local_data_parallel_id * args.mp_num - logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log") + if args.mp_num > 1: + logger = get_logger("cache_transfer", f"cache_transfer_{rank_id}.log") + else: + logger = get_logger("cache_transfer", "cache_transfer.log") + logger.info(f"args: {vars(args)}") set_device(args.device_id) try: diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index f5b81c02719..318fbccb975 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -31,7 +31,9 @@ from fastdeploy import envs from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus from fastdeploy.cache_manager.cache_metrics import CacheMetrics +from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask from fastdeploy.cache_manager.ops import get_all_visible_devices +from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus from fastdeploy.metrics.metrics import main_process_metrics @@ -47,7 +49,7 @@ class PrefixCacheManager: def __init__( self, - config, + config: FDConfig, tensor_parallel_size, splitwise_role="mixed", local_data_parallel_id=0, @@ -207,7 +209,6 @@ def launch_cache_manager( key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num) key_cache_shape = ",".join([str(i) for i in key_cache_shape]) val_cache_shape = ",".join([str(i) for i in val_cache_shape]) - logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}") if self.enable_splitwise: cache_messager_processes = self.launch_cache_messager( cache_config, @@ -273,6 +274,7 @@ def launch_cache_manager( + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}" + f" {sys.executable} {py_path}" + + f" --model_id {os.path.basename(self.config.model_config.model)}" + f" --device_id {int(device_ids[i])}" + f" --rank {i}" + f" --splitwise_role {self.splitwise_role}" @@ -390,7 +392,7 @@ def launch_cache_messager( + f" --ipc_suffix {ipc_suffix}" + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}" + f" 
--speculative_config '{self.speculative_config.to_json_string()}'" - + f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1" + + f" >{log_dir}/launch_cache_messager_{i}.log 2>&1" ) logger.info(f"Launch cache messager, command:{launch_cmd}") cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) @@ -789,9 +791,15 @@ def request_match_blocks(self, task: Request, block_size, *args): f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}" ) start_time = time.time() - storage_matched_block_ids = self.issue_prefetch_storage_task( - req_id, no_match_block_keys, gpu_recv_storage_block_ids + read_storage_task = ReadStorageTask( + task_id=req_id, + keys=no_match_block_keys, + token_ids=input_token_ids, + gpu_block_ids=gpu_recv_storage_block_ids, + start_read_block_idx=match_token_num // block_size, ) + logger.debug(f"issue read storage task: {read_storage_task}") + storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task) storage_matched_block_num = len(storage_matched_block_ids) storage_match_token_num = storage_matched_block_num * block_size cost_time = time.time() - start_time @@ -1006,6 +1014,12 @@ def write_cache_to_storage(self, request: Request): if self.kvcache_storage_backend is None: return + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + if self.config.cache_config.enable_output_caching: + token_ids += request.output_token_ids + req_id = request.request_id keys = [] node = self.req_leaf_map[req_id] @@ -1018,24 +1032,33 @@ def write_cache_to_storage(self, request: Request): gpu_block_ids = request.block_tables[: len(keys)] logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}") + write_storage_task = WriteStorageTask( + task_id=req_id, + keys=keys, + token_ids=token_ids, + gpu_block_ids=gpu_block_ids, + ) + logger.debug(f"issue write storage task: {write_storage_task}") tic = time.time() - self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True) + self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") - def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5): + def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): if self.kvcache_storage_backend is None: return - if len(hash_keys) != len(gpu_block_ids): - err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})" + if len(task.keys) != len(task.gpu_block_ids): + err_msg = ( + f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})" + ) logger.error(err_msg) raise ValueError(err_msg) - self.task_write_back_event[req_id] = Event() - self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout)) + self.task_write_back_event[task.task_id] = Event() + self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task)) if is_sync: - self.wait_write_storage_task(req_id) + self.wait_write_storage_task(task.task_id) def wait_write_storage_task(self, req_id): """ @@ -1045,16 +1068,19 @@ def wait_write_storage_task(self, req_id): self.task_write_back_event[req_id].wait() del self.task_write_back_event[req_id] - def issue_prefetch_storage_task(self, 
req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5): + def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True): """ Prefetch cache from storage task """ + if self.kvcache_storage_backend is None: + return [] + storage_block_ids = [] - self.task_prefetch_event[req_id] = Event() + self.task_prefetch_event[task.task_id] = Event() # issue task to cache_transfer_manager - self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout)) + self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task)) if is_sync: - storage_block_ids = self.wait_prefetch_storage_task(req_id) + storage_block_ids = self.wait_prefetch_storage_task(task.task_id) return storage_block_ids def wait_prefetch_storage_task(self, req_id): diff --git a/fastdeploy/cache_manager/transfer_factory/__init__.py b/fastdeploy/cache_manager/transfer_factory/__init__.py index f9c7d5dc979..34561f15ba4 100644 --- a/fastdeploy/cache_manager/transfer_factory/__init__.py +++ b/fastdeploy/cache_manager/transfer_factory/__init__.py @@ -17,7 +17,7 @@ from fastdeploy.platforms import current_platform from .kvcache_storage import KVCacheStorage -from .mooncake_store import MooncakeStore +from .mooncake_store import AttentionStore, MooncakeStore from .rdma_cache_transfer import RDMACommManager if current_platform.is_cuda(): @@ -31,4 +31,5 @@ "RDMACommManager", "KVCacheStorage", "MooncakeStore", + "AttentionStore", ] diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py b/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py index 2345c061aaf..2428babf4de 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_storage.py @@ -95,3 +95,10 @@ def clear(self) -> bool: Clear all keys in storage """ pass + + @abstractmethod + def query(self) -> int: + """ + Query the number of blocks stored in the storage. + """ + pass diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py index 1de4084be6b..00cbd4acf7b 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. """ +from .attention_store import AttentionStore from .mooncake_store import MooncakeStore -__all__ = ["MooncakeStore"] +__all__ = ["MooncakeStore", "AttentionStore"] diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py new file mode 100644 index 00000000000..c67ac22574d --- /dev/null +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -0,0 +1,209 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
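+#
+# AttentionStore KV-cache storage backend for FastDeploy, built on top of
+# attentionstore_sdk. Unlike the Mooncake backend, read/write operate directly on
+# the GPU key/value cache tensors (via their device pointers and block ids), so no
+# intermediate CPU staging buffer is required.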
+""" + +import time +import traceback +from dataclasses import dataclass +from typing import List + +import paddle + +from fastdeploy.cache_manager.transfer_factory.kvcache_storage import ( + KVCacheStorage, + logger, +) + +try: + from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens + from attentionstore_sdk.utils.err import AttentionStoreSDKError + + _ATTENTIONSTORE_AVAILABLE = True +except Exception: + AttentionStoreSDK = None + Tokens = None + AttentionStoreSDKError = None + _ATTENTIONSTORE_AVAILABLE = False + + +@dataclass +class AttentionStoreConfig: + namespace: str = "default_ns" + pod_name: str = "default_pod" + model_version: str = "v0" + shard_id: int = 0 + shard_num: int = 1 + layer_num: int = 1 + block_token_size: int = 64 + bytes_per_shard_layer_per_block: int = 1024 + device_id: int = 0 + dp_id: int = 0 + + +class AttentionStore(KVCacheStorage): + def __init__(self, **args): + + if not _ATTENTIONSTORE_AVAILABLE: + raise ImportError("Please install attentionstore_sdk to run Fastdeploy with attentionstore_sdk.") + + self.config = AttentionStoreConfig(**args) + + try: + logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}") + self.sdk = AttentionStoreSDK( + self.config.namespace, + self.config.pod_name, + self.config.model_version, + self.config.shard_id, + self.config.shard_num, + self.config.layer_num, + self.config.block_token_size, + self.config.bytes_per_shard_layer_per_block, + self.config.device_id, + self.config.dp_id, + ) + self.wait_for_sdk_ready(timeout=300, delta_t=5) + logger.info("[INIT] ✅ AttentionStore is initialized successfully!") + except Exception as e: + logger.error( + f"[INIT] ❌ AttentionStore initialization failed, error: {e}, traceback:\n{traceback.format_exc()}" + ) + + def wait_for_sdk_ready(self, timeout: float, delta_t: float): + t = 0 + while t < timeout: + try: + tokens = Tokens(list(range(self.config.block_token_size + 1)), self.config.block_token_size) + self.sdk.match(tokens, 0, delta_t) + return + except AttentionStoreSDKError as e: + if "cuda memory not ready" in str(e): + logger.debug("[INIT] cuda memory not ready, try again..") + time.sleep(delta_t) + t += delta_t + continue + else: + raise RuntimeError( + f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}" + ) + raise TimeoutError(f"AttentionStoreSDK initialization timed out after {timeout} seconds") + + def read( + self, + task_id: str, + key_cache: List[paddle.Tensor], + val_cache: List[paddle.Tensor], + token_ids: List[int], + gpu_block_ids: List[int], + start_read_block_idx: int, + timeout: float = 30.0, + ): + logger.debug( + f"[READ BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_read_block_idx: {start_read_block_idx} timeout: {timeout}" + ) + tokens = Tokens(token_ids, self.config.block_token_size) + k_data_ptrs = [k.data_ptr() for k in key_cache] + v_data_ptrs = [v.data_ptr() for v in val_cache] + num = 0 + try: + num = self.sdk.read( + list(range(self.config.layer_num)), + tokens, + start_read_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) + logger.debug(f"[READ END] task_id: {task_id} read_blocks: {num}") + except AttentionStoreSDKError: + logger.error( + f"[READ ERROR] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}" + ) + return num + + def write( + self, + task_id: str, + key_cache: List[paddle.Tensor], + val_cache: List[paddle.Tensor], + token_ids: List[int], + gpu_block_ids: List[int], + 
start_write_block_idx: int, + timeout: float = 30.0, + ) -> int: + logger.debug( + f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}" + ) + tokens = Tokens(token_ids, self.config.block_token_size) + k_data_ptrs = [k.data_ptr() for k in key_cache] + v_data_ptrs = [v.data_ptr() for v in val_cache] + num = 0 + try: + num = self.sdk.write( + list(range(self.config.layer_num)), + tokens, + start_write_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) + logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}") + except AttentionStoreSDKError: + logger.error( + f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}" + ) + return num + + def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0): + """ + Given the input ids and starting index to match, get the valid blocks number that + can be prefetched from storage backend. + """ + logger.debug( + f"[QUERY BEGIN] task_id: {task_id} token_ids: {token_ids} start_match_block_idx: {start_match_block_idx} timeout: {timeout}" + ) + tokens = Tokens(token_ids, self.config.block_token_size) + num = 0 + try: + num = self.sdk.match(tokens, start_match_block_idx, timeout) + logger.debug(f"[QUERY END] task_id: {task_id} matched_blocks: {num}") + except AttentionStoreSDKError: + logger.error( + f"[QUERY ERROR] Failed to execute sdk match, task_id: {task_id}, traceback:\n{traceback.format_exc()}" + ) + return num + + def get(self, **kwargs): + raise NotImplementedError("AttentionStore does not support this method") + + def batch_get(self, **kwargs): + raise NotImplementedError("AttentionStore does not support this method") + + def set(self, **kwargs) -> bool: + raise NotImplementedError("AttentionStore does not support this method") + + def batch_set(self, **kwargs) -> bool: + raise NotImplementedError("AttentionStore does not support this method") + + def exists(self, keys: List[str]) -> bool: + raise NotImplementedError("AttentionStore does not support this method") + + def clear(self) -> bool: + raise NotImplementedError("AttentionStore does not support this method") + + def register_buffer(self, buffer_ptr, buffer_size, buffer_type="none_type") -> None: + raise NotImplementedError("AttentionStore does not support this method") diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py index ac7d793ad61..ccb378ca94c 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -237,6 +237,21 @@ def exists(self, keys: List[str]): logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms") return result + def query(self, k_keys: List[str], v_keys: List[str], timeout: float = 1.0): + """ + Given the k_keys and v_keys, get the valid blocks number that + can be prefetched from storage backend. + """ + assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length." 
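+ # NOTE: `timeout` is kept for interface parity with AttentionStore.query but is
+ # not used here; exists() checks all key/value entries in a single call.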
+ result = self.exists(k_keys + v_keys) + + # only consider the case when both key and value exist + num = 0 + for k, v in zip(k_keys, v_keys): + if result[k] and result[v]: + num += 1 + return num + def delete(self, key, timeout=5) -> bool: while timeout: result = self.store.remove(key) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index cda87372e44..af1e97d896f 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -629,7 +629,6 @@ def post_init_ports(name: str, ports: list, num_total_ports: int): for port in cur_dp_ports: assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use." - console_logger.debug(f"post init {name}: {ports}") return ports num_nodes = len(self.ips) if self.ips else 1 @@ -1077,7 +1076,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--kvcache-storage-backend", type=nullable_str, - choices=["mooncake"], + choices=["mooncake", "attention_store"], default=EngineArgs.kvcache_storage_backend, help="The storage backend for kvcache storage. Leave empty to disable.", ) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 8f98ace1a97..784ed71c684 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -225,10 +225,6 @@ def check_worker_initialize_status_func(res: dict): device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) - # Set cache manager signal - if self.cfg.scheduler_config.splitwise_role != "mixed": - self.launched_cache_manager_signal.value[0] = 1 - # Worker launched self.check_worker_initialize_status_func_thread.join() if not result_container["worker_is_alive"]: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 5424b718dbe..cf1ca0e59ff 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -181,10 +181,6 @@ def check_worker_initialize_status_func(res: dict): device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix) - # Launch components: scheduler, cache_manager, expert_service et.al. 
- if self.cfg.scheduler_config.splitwise_role != "mixed": - self.launched_cache_manager_signal.value[0] = 1 - if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER: envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0] envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0] diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index c07a7c518b0..83396ddf753 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1338,8 +1338,11 @@ def clear_data(self): def update_metrics(self): # Update metrics num_tasks = sum([1 if task else 0 for task in self.tasks_list]) - num_blocks_used_by_tasks = sum([len(task.block_tables) if task else 0 for task in self.tasks_list]) - main_process_metrics.available_gpu_block_num.set(self.total_block_number() - num_blocks_used_by_tasks) + blocks_used_by_tasks = set() + for task in self.tasks_list: + if task is not None: + blocks_used_by_tasks.update(task.block_tables) + main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks)) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) main_process_metrics.num_requests_running.set(len(self.running)) diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 92fcfe617a3..b27fac3bd63 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -29,6 +29,7 @@ class Args: mp_num = 1 device_id = 0 speculative_config = {} + model_id = "test_model" ipc_suffix = "test_ipc_suffix" cache_queue_port = 9999 pod_ip = "127.0.0.1" diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index 52d1c01040e..1045f860a0c 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -185,6 +185,7 @@ def _create_manager( swap_space=4, ) model_config = SimpleNamespace( + model="test_model", num_attention_heads=1, num_key_value_heads=1, head_dim=1, diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 934306f92eb..efa3af9a066 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -332,8 +332,6 @@ def fake_init_signals(): self.assertFalse(ok) # cache manager started before workers (lines 184-185) self.assertTrue(started_cache.get("called", False)) - # launched_cache_manager_signal set (line 221) - self.assertEqual(int(eng.launched_cache_manager_signal.value[0]), 1) # avoid atexit finalizer if hasattr(eng, "_finalizer"): try: