diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..b5e50b4 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +text=auto eol=lf diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2207640 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.ruff_cache +_version.py +build +*.so +*.pyc +.pytest_cache + +# Ascend Specific +fusion_result.json \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..27c1823 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/kvcache-ops"] + path = third_party/kvcache-ops + url = https://gitee.com/openeuler/kvcache-ops.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..54db16c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,60 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved. + +# CMake lowest version requirement +cmake_minimum_required(VERSION 3.16.0) +# project information +project(c_ops) + +set(CMAKE_CXX_STANDARD 17) +set(LMC_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}") +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) + +set(SOC_VERSION ${SOC_VERSION}) +set(ARCH ${ARCH}) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRINGS "Build type Release/Debug (default Release)" FORCE) +endif() + +if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + set(ARCH_SUBDIR "aarch64-linux") +elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + set(ARCH_SUBDIR "x86_64-linux") +else() + message(FATAL_ERROR "Unsupported architecture: ${CMAKE_SYSTEM_PROCESSOR}") +endif() + +add_subdirectory(third_party/kvcache-ops) +add_subdirectory(csrc) + + +set(TORCH_LIBS_DIR "${TORCH_PATH}/lib") + +target_link_options(c_ops PRIVATE + "-Wl,-rpath,$ORIGIN:$ORIGIN/lib" + "-Wl,-rpath,${LMC_INSTALL_PATH}" +) + +target_link_directories( + c_ops + PRIVATE + ${TORCH_LIBS_DIR} + ${TORCH_NPU_PATH}/lib/ + ${ASCEND_CANN_PACKAGE_PATH}/lib64 + ${ASCEND_CANN_PACKAGE_PATH}/${ARCH_SUBDIR}/devlib +) + +target_link_libraries( + c_ops + PUBLIC + ${TORCH_LIBRARIES} + libtorch_npu.so + cache_kernels + ascendcl + platform + ascend_hal + tiling_api +) + + +install(TARGETS c_ops cache_kernels DESTINATION ${LMC_INSTALL_PATH}) \ No newline at end of file diff --git a/README.md b/README.md index c207a61..0b5fb10 100644 --- a/README.md +++ b/README.md @@ -14,4 +14,121 @@ -------------------------------------------------------------------------------- +## Overview +LMCache-Ascend is a community maintained plugin for running LMCache on the Ascend NPU. + + +## Prerequisites + +To use LMCache-Ascend on the NPU hardware, please make sure the following prerequisites are satisfied. + +- Hardware: Atlas 800I A2 Inference series. The rest of the series like A3 Inference/Training and 300I Duo are experimental. +- OS: Linux-based. +- Software: + - **Python**: >= 3.10, <= 3.11 + - **CANN Toolkit**: >= 8.2rc1 + - **Ascend Driver**: >= 24.1 + - **PyTorch**: == 2.5.1, **Torch-npu**: == 2.5.1.post1.dev20250619 + - **vLLM**: v0.9.2 & **vLLM-Ascend**: v0.9.2rc1 + +## Getting Started + +### Clone LMCache-Ascend Repo + +Our repo contains a kvcache ops submodule for ease of maintainence, therefore we recommend cloning the repo with submodules. + +```bash +cd /workspace +git clone --recurse-submodules https://github.com/LMCache/LMCache-Ascend.git +``` + +### Docker + +```bash +cd /workspace/LMCache-Ascend +docker build -f docker/Dockerfile.a2.openEuler -t lmcache-ascend:v0.3.3-vllm-ascend-v0.9.2rc1-910b-cann-8.2rc1-py3.11-openeuler-22.03 . 
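+# Note: the tag above simply mirrors the pinned stack (vLLM v0.9.2, vLLM-Ascend
+# v0.9.2rc1, LMCache v0.3.3, CANN 8.2.rc1); any value passed to -t works.
+# The Dockerfile also exposes a PIP_INDEX_URL build arg, so the pip mirror can
+# be overridden at build time if needed, e.g. (mirror URL here is just an example):
+# docker build --build-arg PIP_INDEX_URL=https://pypi.org/simple \
+#   -f docker/Dockerfile.a2.openEuler -t lmcache-ascend:dev .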
+```
+
+Once the image is built, run it with the following command:
+```bash
+DEVICE_LIST="0,1,2,3,4,5,6,7"
+docker run -it \
+    --privileged \
+    --cap-add=SYS_PTRACE \
+    --net=host \
+    --name lmcache-ascend-dev \
+    --rm \
+    -e ASCEND_VISIBLE_DEVICES=${DEVICE_LIST} \
+    -e ASCEND_RT_VISIBLE_DEVICES=${DEVICE_LIST} \
+    -e ASCEND_TOTAL_MEMORY_GB=32 \
+    -e VLLM_TARGET_DEVICE=npu \
+    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+    -v /usr/local/sbin/npu-smi:/usr/local/sbin/npu-smi \
+    -v /etc/localtime:/etc/localtime \
+    -v /usr/local/dcmi:/etc/local/dcmi \
+    -v /var/log/npu:/var/log/npu \
+    -v /sys/fs/cgroup:/sys/fs/cgroup:ro \
+    -v /dev/davinci_manager:/dev/davinci_manager \
+    -v /dev/devmm_svm:/dev/devmm_svm \
+    -v /etc/ascend_install.info:/etc/ascend_install.info \
+    -v /etc/hccn.conf:/etc/hccn.conf \
+    lmcache-ascend:v0.3.3-vllm-ascend-v0.9.2rc1-910b-cann-8.2rc1-py3.11-openeuler-22.03 \
+    /bin/bash
+```
+
+### Manual Installation
+
+Assuming your working directory is `/workspace`.
+
+1. Clone and Install vLLM Repo
+```bash
+VLLM_REPO=https://github.com/vllm-project/vllm.git
+VLLM_TAG=v0.9.2
+git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /workspace/vllm
+# NOTE: There is an Ascend Triton but we don't currently support it properly.
+VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \
+    python3 -m pip uninstall -y triton
+```
+
+2. Clone and Install vLLM Ascend Repo
+```bash
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+source /usr/local/Ascend/nnal/atb/set_env.sh
+
+VLLM_ASCEND_REPO=https://github.com/vllm-project/vllm-ascend.git
+VLLM_ASCEND_TAG=v0.9.2rc1
+git clone --depth 1 $VLLM_ASCEND_REPO --branch $VLLM_ASCEND_TAG /workspace/vllm-ascend
+# apply patch to v0.9.2rc1
+cd /workspace/vllm-ascend && \
+    git apply -p1 /workspace/LMCache-Ascend/docker/kv-connector-v1.diff
+
+export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi
+
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
+python3 -m pip install -v -e /workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/
+```
+
+3. Clone and Install LMCache Repo
+
+```bash
+LMCACHE_REPO=https://github.com/LMCache/LMCache.git
+LMCACHE_TAG=v0.3.3
+git clone --depth 1 $LMCACHE_REPO --branch $LMCACHE_TAG /workspace/LMCache
+# our build is based on arm64
+sed -i "s/^infinistore$/infinistore; platform_machine == 'x86_64'/" /workspace/LMCache/requirements/common.txt
+export NO_CUDA_EXT=1 && python3 -m pip install -v -e /workspace/LMCache
+```
+
+4. Install LMCache-Ascend Repo
+
+```bash
+cd /workspace/LMCache-Ascend
+python3 -m pip install -v --no-build-isolation -e .
+```
+
+## FAQ
+
+1. Why do I get a HostRegisterError?
+   - If you encounter the HostRegisterError within a container environment, please make sure you add the IPC_LOCK capability.
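+     For example, if the container is not started with `--privileged`, the capability can be granted explicitly; a minimal sketch (in practice keep all the device and volume flags from the Docker section above):
+     ```bash
+     docker run -it --rm --cap-add=IPC_LOCK \
+         -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+         lmcache-ascend:v0.3.3-vllm-ascend-v0.9.2rc1-910b-cann-8.2rc1-py3.11-openeuler-22.03 \
+         /bin/bash
+     ```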
+ - Otherwise, please check your driver version is >= 24.0 diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt new file mode 100644 index 0000000..22c4487 --- /dev/null +++ b/csrc/CMakeLists.txt @@ -0,0 +1,44 @@ +include(utils.cmake) +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11") +find_package(pybind11 REQUIRED) + +message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") + +file(GLOB SRC_FILES +${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) + +find_package(Torch REQUIRED) + +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR} + ${pybind11_INCLUDE_DIRS} + ${PYTHON_INCLUDE_PATH} + ${TORCH_INCLUDE_DIRS} + ${TORCH_NPU_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include + ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/include/experiment/ascend_hal + ${ASCEND_CANN_PACKAGE_PATH}/x86_64-linux/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/x86_64-linux/include/experiment/ascend_hal +) + + +set( + INCLUDES + ${TORCH_INCLUDE_DIRS} + ${TORCH_NPU_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include + ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/include/experiment/ascend_hal +) + +set(PYMODULE_FILES + ${SRC_FILES} +) + +pybind11_add_module(c_ops ${PYMODULE_FILES}) diff --git a/csrc/cachegen_kernels.cpp b/csrc/cachegen_kernels.cpp new file mode 100644 index 0000000..fdf865f --- /dev/null +++ b/csrc/cachegen_kernels.cpp @@ -0,0 +1,32 @@ +#include "cachegen_kernels.h" +#include +#include + +namespace py = pybind11; + +void encode_cuda_new(const at::Tensor& cdf, const at::Tensor& input_sym, + at::Tensor& output_buffer, at::Tensor& output_lengths) { + // TODO: + PyErr_SetString(PyExc_NotImplementedError, "Please contact LMCache Ascend."); + throw py::error_already_set(); +}; + +void decode_cuda_new(const at::Tensor& cdf, const at::Tensor& bytestreams, + const at::Tensor& lengths, at::Tensor& output) { + // TODO: + PyErr_SetString(PyExc_NotImplementedError, "Please contact LMCache Ascend."); + throw py::error_already_set(); +}; + +void decode_cuda_prefsum(const at::Tensor& cdf, const at::Tensor& bytestreams, + const at::Tensor& lengths, at::Tensor& output) { + // TODO: + PyErr_SetString(PyExc_NotImplementedError, "Please contact LMCache Ascend."); + throw py::error_already_set(); +}; + +at::Tensor calculate_cdf(const at::Tensor& input, const int max_bins) { + // TODO: + PyErr_SetString(PyExc_NotImplementedError, "Please contact LMCache Ascend."); + throw py::error_already_set(); +}; \ No newline at end of file diff --git a/csrc/cachegen_kernels.h b/csrc/cachegen_kernels.h new file mode 100644 index 0000000..d1a1701 --- /dev/null +++ b/csrc/cachegen_kernels.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include +#include +#include + +void encode_cuda_new(const at::Tensor& cdf, const at::Tensor& input_sym, + at::Tensor& output_buffer, at::Tensor& output_lengths); + +void decode_cuda_new(const at::Tensor& cdf, const at::Tensor& bytestreams, + const at::Tensor& lengths, at::Tensor& output); + +void decode_cuda_prefsum(const at::Tensor& cdf, const at::Tensor& bytestreams, + const at::Tensor& lengths, at::Tensor& output); + +at::Tensor calculate_cdf(const at::Tensor& input, const int max_bins); \ No newline at end of file diff --git 
a/csrc/managed_mem.cpp b/csrc/managed_mem.cpp new file mode 100644 index 0000000..4c79f54 --- /dev/null +++ b/csrc/managed_mem.cpp @@ -0,0 +1,326 @@ +#include "managed_mem.h" +#include +// Only required for old driver version (look at registerHostPtr) +#ifdef PROF_ERROR + // You can add a pragma message to see this in your build log if you want: + // #pragma message("Undefining PROF_ERROR from ascend_hal.h before NPU headers") + #undef PROF_ERROR +#endif + +#include +#include +#include "driver/ascend_hal_define.h" +#include "driver/ascend_hal.h" +#include +#include "torch/torch.h" +#include "torch/extension.h" + +namespace lmc { +constexpr int32_t PROT_FLAGS = static_cast(PROT_READ) | static_cast(PROT_WRITE); +constexpr int32_t MAP_FLAGS = static_cast(MAP_PRIVATE) | static_cast(MAP_ANONYMOUS) | static_cast(MAP_POPULATE); + +// Signatures for internal helper functions + +// Get the version of the NPU driver as a string +std::string get_driver_version(); +// Checks whether the major version of the NPU is greater or equal 25 to support aclrtHostRegister +bool is_version_at_least_25(const std::string& version_str); +// Gets the current device offsetting on ASCEND_RT_VISIBLE_DEVICES when needed +int get_device(); +// Uregisters the malloced hostPtr +void unregisterPtr(void* ptr); +// Swaps the host memory allocated to a tensor with the given hostPtr +void swap_tensor_ptr(void* hostPtr, torch::Tensor& original_tensor); + +// Class implementations + +HostRegisteredMemoryManager::HostRegisteredMemoryManager(){ +}; + +HostRegisteredMemoryManager::~HostRegisteredMemoryManager() { + this->unregisterAll(); +}; + +void HostRegisteredMemoryManager::unregisterAll(){ + const std::unique_lock guard(this->mux); + + // Iterate through each key-value pair in the map. + for (const auto& pair : this->allocatedMap) { + void* hostPtr = pair.first; + aclrtHostUnregister(hostPtr); + } + + // After unregistering all pointers, clear the map completely. + this->allocatedMap.clear(); +}; + +// Register a pointer through high level APIs (aclrt) return devPtr +// Returns the created RegisteredMemoryRecord +RegisteredMemoryRecord HostRegisteredMemoryManager::registerHostPtr(void* hostPtr, size_t bufferSize) { // torch::Tensor& tensor){ + TORCH_CHECK(!(hostPtr == nullptr || bufferSize == 0), "Error: hostPtr cannot be null and bufferSize must be greater than 0."); + const std::unique_lock guard(this->mux); + + // Check if the host pointer is already registered + if (this->allocatedMap.count(hostPtr)) { + return this->allocatedMap[hostPtr]; + } + + void* devPtr; + aclError err = aclrtHostRegister(hostPtr, static_cast(bufferSize), + ACL_HOST_REGISTER_MAPPED, (void**)&devPtr); + TORCH_CHECK(err == 0, "Unable to host register the host ptr: " + std::to_string(err)); + + this->allocatedMap.emplace(hostPtr, RegisteredMemoryRecord{reinterpret_cast(hostPtr), + reinterpret_cast(devPtr), bufferSize}); + + return this->allocatedMap[hostPtr]; +}; + +// Register a pointer through low level APIs (HAL). Allocates a new pinned host memory +// This should be used for driver versions, where cannot rely on aclrtHostRegister() +// Returns the created RegisteredMemoryRecord +RegisteredMemoryRecord HostRegisteredMemoryManager::halRegisterHostPtr(size_t bufferSize){ + // We allocate a new chunk of memory, register it, and replace the tensor. + // Essentially, the halHostRegister function requires a ptr given by mmap. 
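+  // Sketch of the path below: mmap an anonymous private region, advise the
+  // kernel to back it with transparent huge pages, register it with the HAL
+  // driver to obtain a device-visible pointer, then mlock it so the pages
+  // stay resident (i.e. behave like pinned memory).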
+ TORCH_CHECK((bufferSize >= 0), "Error: bufferSize must be greater than 0."); + const std::unique_lock guard(this->mux); + + void* devPtr; + int device = get_device(); + void* hostPtr; + // Allocate and register + hostPtr = mmap(nullptr, bufferSize, PROT_FLAGS, MAP_FLAGS, -1, 0); + TORCH_CHECK(hostPtr != MAP_FAILED, "Unable to alloc memory with mmap."); + auto ret = madvise(reinterpret_cast(hostPtr), bufferSize, MADV_HUGEPAGE); + auto drvRet = halHostRegister((void*)hostPtr, static_cast(bufferSize), + HOST_MEM_MAP_DEV_PCIE_TH, (UINT32)device, (void**)&devPtr); + TORCH_CHECK(drvRet == 0, "Unable to register host memory with hal: " + std::to_string(drvRet)) + + // Lock the memory and fail if impossible to lock + auto lockErr = mlock(reinterpret_cast(hostPtr), bufferSize); + if (lockErr == -1) { + // This can happen in non-privileged mode or not enough rlimit, + // let's not proceed since we wanted to guarantee pinned + // because we already alloced, let's free + auto ret = halHostUnregisterEx(reinterpret_cast(hostPtr), + static_cast(device), HOST_MEM_MAP_DEV_PCIE_TH); + TORCH_CHECK(ret==0, "Unable to pin host memory, unable to unregister. Error code: " + std::to_string(ret)) + auto mret = munmap(reinterpret_cast(hostPtr), bufferSize); + TORCH_CHECK(false, "Unable to pin host memory with error code: " + std::to_string(lockErr)) + } + + this->allocatedMap.emplace(hostPtr, RegisteredMemoryRecord{reinterpret_cast(hostPtr), + reinterpret_cast(devPtr), bufferSize}); + + return this->allocatedMap[hostPtr]; +}; + +void HostRegisteredMemoryManager::unregisterMemory(void* hostPtr) { + TORCH_CHECK(hostPtr != nullptr, "Error: hostPtr cannot be null."); + + // we don't actually mind if it doesn't unregister, + // at context destroy it should be unregister anyway. + const std::unique_lock guard(this->mux); + aclError err = aclrtHostUnregister(hostPtr); + this->allocatedMap.erase(hostPtr); +}; + +/* +* For now we only do a linear search as we probably won't have a long list of ptrs +* we go through each record and check whether we are in range, if so +* we calculate the offset from the host ptr and apply to the device ptr +* finally we return the device ptr. 
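+* Example: a record covering host range [base, base + buffSize) with device
+* pointer devptr maps a query for (base + 0x40) to (devptr + 0x40); a host
+* pointer that falls outside every registered range yields nullptr.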
+*/ +void* HostRegisteredMemoryManager::getDevicePtr(void* hostPtr) { + if (hostPtr == nullptr) { + return nullptr; + } + const std::shared_lock guard(this->mux); + + const uintptr_t hostAddrPtr = reinterpret_cast(hostPtr); + + for (const auto& pair: this->allocatedMap) { + const RegisteredMemoryRecord& record = pair.second; + + if (hostAddrPtr >= record.ptr && hostAddrPtr < (record.ptr + record.buffSize)) { + const size_t offset = hostAddrPtr - record.ptr; + + const uintptr_t deviceAddrPtr = record.devptr + offset; + + return reinterpret_cast(deviceAddrPtr); + } + } + + return nullptr; +}; + + +size_t HostRegisteredMemoryManager::getRecordSize(void* hostPtr){ + if (hostPtr == nullptr) { + return 0; + } + const std::shared_lock guard(this->mux); + + const uintptr_t hostAddrPtr = reinterpret_cast(hostPtr); + + for (const auto& pair: this->allocatedMap) { + const RegisteredMemoryRecord& record = pair.second; + + if (hostAddrPtr >= record.ptr && hostAddrPtr < (record.ptr + record.buffSize)) { + return record.buffSize; + } + } + return 0; +}; + +std::string get_driver_version() { + void* handle = nullptr; + int (*dsmi_get_version)(int, char*, unsigned int, unsigned int*) = nullptr; + std::string result; + + handle = dlopen("libdrvdsmi_host.so", RTLD_LAZY); + if (!handle) { + TORCH_CHECK(false, std::string("Error opening libdrvdsmi_host.so: ") + dlerror() ); + return result; + } + dlerror(); + + // Load the function + *(void**) (&dsmi_get_version) = dlsym(handle, "dsmi_get_version"); + const char* dlsym_error = dlerror(); + if (dlsym_error) { + dlclose(handle); + TORCH_CHECK(false, std::string("Error loading dsmi_get_version: ") + dlsym_error); + return result; + } + + // Call the function + int device_id = c10_npu::getCurrentNPUStream().device_index(); + const unsigned int buffer_size = 256; + std::vector version_buffer(buffer_size); + unsigned int ret_len = 0; + int ret = dsmi_get_version(device_id, version_buffer.data(), buffer_size, &ret_len); + if (ret == 0) { + if (ret_len > 0 && ret_len <= buffer_size) { + version_buffer[ret_len] = '\0'; // Ensure null-termination + result = version_buffer.data(); + } else { + TORCH_CHECK(false, "Error: Invalid length returned: " + std::to_string(ret_len)); + } + } else { + TORCH_CHECK(false, "Error: dsmi_get_version returned " + std::to_string(ret)); + } + + dlclose(handle); + + return result; +} + +// To be on the safe side, returns false in case of uncertainties +bool is_version_at_least_25(const std::string& version_str) { + if (version_str.empty()) { + return false; + } + + size_t num_end = 0; + long major_version = 0; + + try { + major_version = std::stol(version_str, &num_end); + } catch (const std::invalid_argument&) { + // No valid number at start + return false; + } catch (const std::out_of_range&) { + // Should never happen, here for robustness + return false; + } + return major_version >= 25; +} + +int get_device(){ + int device = c10_npu::getCurrentNPUStream().device_index(); + const char* env_visible_devices_p = std::getenv("ASCEND_RT_VISIBLE_DEVICES"); + // If we are using a custom list of visible devices, the index refers to that + if (env_visible_devices_p != nullptr) { + std::string env_visible_devices = env_visible_devices_p; + std::vector list_visible_devices; + std::stringstream ss(env_visible_devices); + std::string item; + while (std::getline(ss, item, ',')) { + list_visible_devices.push_back(std::stoi(item)); + } + std::sort(list_visible_devices.begin(), list_visible_devices.end()); + // Here two cases are possible: + // 1. 
no hccl, we just use current_device, even though we have specify the ASCEND_RT_VISIBLE_DEVICES + // 2. hccl, and we use current_device that seems to be correct + // for case 2, since the current_device would have been correct anyway, obtaining from the list would be fine. + // for case 1, we have shifted the device to the RT_VISIBLE_DEVICES, so it should be corrected. + device = list_visible_devices[device]; + } + return device; +} + +void unregisterPtr(void* ptr) { + if (ptr){ + int device = get_device(); + auto& hmm = HostRegisteredMemoryManager::GetInstance(); + size_t bufferSize = hmm.getRecordSize(ptr); + auto ret = halHostUnregisterEx(reinterpret_cast(ptr), + static_cast(device), HOST_MEM_MAP_DEV_PCIE_TH); + if (ret != 0) { + std::cout << "Unable to hal host unregister: "<< ret << std::endl; + } + auto mret = munmap(reinterpret_cast(ptr), bufferSize); + if (mret != 0) { + std::cout << "Unable to unmap memory: "<< ret << std::endl; + } + } +} + + +void swap_tensor_ptr(void* hostPtr, torch::Tensor& original_tensor){ + torch::TensorOptions tensorOpsCpu = torch::TensorOptions() + .dtype(original_tensor.dtype()) + .device(original_tensor.device()) + .pinned_memory(true); + int64_t numel = static_cast(original_tensor.nbytes()); + std::vector dims = {numel}; + torch::Tensor new_tensor_from_myptr = torch::from_blob( + hostPtr, dims, unregisterPtr, tensorOpsCpu); + + original_tensor.set_(new_tensor_from_myptr.storage(), original_tensor.storage_offset(), + original_tensor.sizes(), original_tensor.strides()); +} + +} // namespace lmc + + +void* register_memory(torch::Tensor& tensor) { + torch::Device device = tensor.device(); + if (!device.is_cpu() || !tensor.is_pinned()) { + TORCH_CHECK(false, "Invalid device. Device must be CPU and tensor must be pinned."); + } + auto& hmm = lmc::HostRegisteredMemoryManager::GetInstance(); + size_t tensorSize = tensor.nbytes(); + std::string verString = lmc::get_driver_version(); + if (lmc::is_version_at_least_25(verString)) { // New driver version, supports aclrtHostRegister() + void* hostPtr = static_cast(tensor.data_ptr()); + return (void*) hmm.registerHostPtr(hostPtr, tensorSize).devptr; + } else { // Old driver version, does not support aclrtHostRegister(), we have to use HAL. + // We ask for a new registerd memory and substitute with the previously allocated. 
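+    // halRegisterHostPtr() mmaps and registers a fresh pinned buffer, and
+    // swap_tensor_ptr() below re-points the tensor's storage at it (with
+    // unregisterPtr as the deleter), so callers keep using the same tensor
+    // object while its original pinned allocation is released.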
+ lmc::RegisteredMemoryRecord record = hmm.halRegisterHostPtr(tensorSize); + lmc::swap_tensor_ptr((void*) record.ptr, tensor); + return (void*) record.devptr; + } +}; + +void unregister_memory(torch::Tensor& tensor) { + void* hostPtr = static_cast(tensor.data_ptr()); + auto& hmm = lmc::HostRegisteredMemoryManager::GetInstance(); + hmm.unregisterMemory(hostPtr); +}; + +void* get_device_ptr(void* ptr) { + auto& hmm = lmc::HostRegisteredMemoryManager::GetInstance(); + return hmm.getDevicePtr(ptr); +}; diff --git a/csrc/managed_mem.h b/csrc/managed_mem.h new file mode 100644 index 0000000..ae42364 --- /dev/null +++ b/csrc/managed_mem.h @@ -0,0 +1,69 @@ +#pragma once +#include +#include +#include +#include + +namespace lmc { + +struct RegisteredMemoryRecord { + uintptr_t ptr; + uintptr_t devptr; + size_t buffSize; +}; + +/* +* We are not responsible for acl init and ctx initialization, +* we assume the user responsible for ctx initialization +*/ +class HostRegisteredMemoryManager { +private: + HostRegisteredMemoryManager(); + + // Delete copy constructor and assignment operator + HostRegisteredMemoryManager(const HostRegisteredMemoryManager&) = delete; + HostRegisteredMemoryManager& operator=(const HostRegisteredMemoryManager&) = delete; + HostRegisteredMemoryManager(HostRegisteredMemoryManager&&) = delete; + HostRegisteredMemoryManager& operator=(HostRegisteredMemoryManager&&) = delete; + + std::map allocatedMap; + mutable std::shared_mutex mux; + +public: + static HostRegisteredMemoryManager& GetInstance() + { + static HostRegisteredMemoryManager instance; + return instance; + } + ~HostRegisteredMemoryManager(); + + // Register a pointer through high level APIs (aclrt) return devPtr + // Returns an already existing RegisteredMemoryRecord or the newly created one + // Inputs: + // -hostPtr: host pointer of the allocated memory area to register on device + // -bufferSize: size of the allocated memory area to register on device + RegisteredMemoryRecord registerHostPtr(void* hostPtr, size_t bufferSize); //torch::Tensor& tensor); // + // Register a pointer through low level APIs (hal) + // This should be used for driver versions, where cannot rely on aclrtHostRegister() + // Returns the created RegisteredMemoryRecord + // Inputs: + // -bufferSize: size of the allocated memory area to register on device + RegisteredMemoryRecord halRegisterHostPtr(size_t bufferSize); + void unregisterMemory(void* hostPtr); + void* getDevicePtr(void* hostPtr); + size_t getRecordSize(void* hostPtr); + void unregisterAll(); +}; +} // namespace lmc + +// Register a tensor on the current device +// Inputs: +// -tensor: The tensor to register on the device +// Returns the device ptr for that tensor +void* register_memory(torch::Tensor& tensor); +// Reverse of register +// Inputs: +// -tensor: The tensor to register on the device +void unregister_memory(torch::Tensor& tensor); +// Takes in input a host pointer, returns the corresponding device pointer +void* get_device_ptr(void* ptr); diff --git a/csrc/mem_kernels.cpp b/csrc/mem_kernels.cpp new file mode 100644 index 0000000..2465eb7 --- /dev/null +++ b/csrc/mem_kernels.cpp @@ -0,0 +1,245 @@ +#include "mem_kernels.h" +#include +#include +#include +#include +#include "utils.h" +#include "tiling/platform/platform_ascendc.h" +#include +#include + +template +T* get_kernel_ptr(TENSOR_TYPE& tensor) { + torch::Device device = tensor.device(); + // NPU should be using PrivateUse1 + if (device.is_privateuseone() || device.is_cuda()) { + return static_cast(tensor.data_ptr()); + } 
else if (device.is_cpu()) { + // find device ptr based on the host pinned ptr + // because acl does not currently support HostGetDevicePointer API + void* devPtr = get_device_ptr(tensor.data_ptr()); + TORCH_CHECK(devPtr != nullptr, "Unable to retrieve device ptr, is this a host registered pointer ?"); + return reinterpret_cast(devPtr); + } else { + TORCH_CHECK(false, "Invalid device. Device must be ascend (PrivateUseOne) or pinned cpu."); + } +} + +/** + * Quickly offload KV cache from vLLM paged memory to the offloading buffer + * Processes all the layers at the same time + * + * Each layer in vLLM's KV buffer has a shape of + * [2, PAGE_BUFFER_SIZE, num_heads*head_size] + * + * Each AIV Core processes the copy for a token + * + * Therefore: + * AIV Core - token + * + * The function does: + * slot_id = slot_mapping[tokenId] + * ptrs[mem_offset(kv, layer, tokenId, hiddenDims)] = key_value[mem_offset(kv, layer, pages, pageSize, slot_id, hiddenDims)] + * + * Param: + * - direction: false means LMCache to PagedBuffer, true means PagedBuffer to + * LMCache + */ +void multi_layer_kv_transfer(torch::Tensor& key_value, // [kv, num_layer, num_tokens, hidden] + const torch::Tensor& key_value_ptrs, // [num_layers] + const torch::Tensor& slot_mapping, // [num_tokens] + const torch::Device& paged_memory_device, + const int page_buffer_size, const bool direction, + const bool use_mla) { + uint8_t* key_value_ptr = get_kernel_ptr(key_value); + // it is actually a uint8_t**. we will reinterpret it inside the kernel + uint8_t* page_buffer_ptrs = get_kernel_ptr(key_value_ptrs); + uint8_t* slot_mapping_ptr = get_kernel_ptr(slot_mapping); + + int num_layers = key_value.size(1); + int num_tokens = slot_mapping.size(0); + int hidden_dims = key_value.size(-1); + int kv_size = 2; + if (use_mla) { + kv_size = 1; + } + + const c10::OptionalDeviceGuard device_guard(paged_memory_device); + // we require the kv ptr list to be on the device too + const c10::OptionalDeviceGuard kv_device_guard(device_of(key_value_ptrs)); + + const aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + at::ScalarType scalar_type = key_value.scalar_type(); + at::ScalarType slot_type = slot_mapping.scalar_type(); + const char* socName = aclrtGetSocName(); + + at_npu::native::OpCommand cmd; + cmd.Name("multi_layer_kv_transfer_kernel"); + cmd.SetCustomHandler([scalar_type, slot_type, socName, stream, page_buffer_ptrs, key_value_ptr, + slot_mapping_ptr, hidden_dims, kv_size, num_layers, page_buffer_size, + num_tokens, direction]()->int{ + auto slot_num = vllm_ascend::get_dtype_from_torch(slot_type); + auto dtype_num = vllm_ascend::get_dtype_from_torch(scalar_type); + auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(socName); + uint32_t aiv_num = ascendcPlatform->GetCoreNumAiv(); + kvcache_ops::multi_layer_kv_transfer_kernel(dtype_num, slot_num, aiv_num, stream, page_buffer_ptrs, key_value_ptr, + slot_mapping_ptr, hidden_dims, kv_size, num_layers, page_buffer_size, + num_tokens, direction); + return 0; + }); + cmd.Run(); + return ; +}; + + +void multi_layer_kv_transfer_unilateral(torch::Tensor& key_value, + const torch::Tensor& key_ptrs, + const torch::Tensor& value_ptrs, + const torch::Tensor& slot_mapping, + const torch::Device& paged_memory_device, + const int page_buffer_size, + const bool direction){ + // TODO: + PyErr_SetString(PyExc_NotImplementedError, "Please contact LMCache Ascend."); + throw py::error_already_set(); +}; + + +void single_layer_kv_transfer(torch::Tensor& lmc_key_value_cache, // 
[num_tokens, 2, num_heads*head_size] + // or + // [2, num_tokens, num_heads*head_size] + torch::Tensor& vllm_key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& vllm_value_cache, // [....] + torch::Tensor& slot_mapping, // [num_tokens] + const bool direction, // false: LMCache to PagedBuffer, true: PagedBuffer to LMCache + const bool token_major // true: lmc_key_value_cache is [num_tokens, 2, num_heads*head_size] + // false: otherwise +) { + uint8_t *lmc_key_value_cache_ptr = get_kernel_ptr(lmc_key_value_cache); + uint8_t *vllm_key_cache_ptr = get_kernel_ptr(vllm_key_cache); + uint8_t *vllm_value_cache_ptr = get_kernel_ptr(vllm_value_cache); + uint8_t *slot_mapping_ptr = get_kernel_ptr(slot_mapping); + + int num_tokens = slot_mapping.size(0); + int hidden_dims = lmc_key_value_cache.size(-1); + + const c10::OptionalDeviceGuard device_guard(device_of(vllm_key_cache)); + const c10::OptionalDeviceGuard slot_device_guard(device_of(slot_mapping)); + const aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + + at::ScalarType scalar_type = vllm_key_cache.scalar_type(); + at::ScalarType slot_type = slot_mapping.scalar_type(); + + const char* socName = aclrtGetSocName(); + + at_npu::native::OpCommand cmd; + cmd.Name("single_layer_kv_transfer_kernel"); + cmd.SetCustomHandler([scalar_type, slot_type, socName, stream, lmc_key_value_cache_ptr, + vllm_key_cache_ptr, vllm_value_cache_ptr, slot_mapping_ptr, + hidden_dims, num_tokens, direction, token_major]() -> int { + auto slot_num = vllm_ascend::get_dtype_from_torch(slot_type); + auto dtype_num = vllm_ascend::get_dtype_from_torch(scalar_type); + auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(socName); + uint32_t aiv_num = ascendcPlatform->GetCoreNumAiv(); + // TODO: We will add the isMLA argument once the signature have support for the MLA. 
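+      // The trailing `false` passed below is the kernel's isMLA flag (see the
+      // declaration in mem_kernels.h); it stays hard-coded until MLA support
+      // is wired through this entry point.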
+ kvcache_ops::single_layer_kv_transfer_kernel(dtype_num, slot_num, aiv_num, stream, lmc_key_value_cache_ptr, + vllm_key_cache_ptr, vllm_value_cache_ptr, slot_mapping_ptr, + hidden_dims, num_tokens, direction, token_major, false); + return 0; + }); + cmd.Run(); + return ; +}; + +void load_and_reshape_flash( + torch::Tensor& key_value, // [2, num_layer, num_tokens, num_heads*head_size] + // must be one gpu / pinned cpu + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens], + const int layer_idx) { + + uint8_t* key_value_ptr = get_kernel_ptr(key_value); + uint8_t* key_cache_ptr = get_kernel_ptr(key_cache); + uint8_t* value_cache_ptr = get_kernel_ptr(value_cache); + + uint8_t* slot_mapping_ptr = get_kernel_ptr(slot_mapping); + + int num_tokens = slot_mapping.size(0); + int num_layers = key_value.size(1); + int block_size = key_cache.size(1); + int num_blocks = key_cache.size(0); + int hidden_dims = key_value.size(-1); + const c10::OptionalDeviceGuard device_guard(device_of(key_cache)); + const aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + + at::ScalarType scalar_type = key_value.scalar_type(); + at::ScalarType slot_type = slot_mapping.scalar_type(); + const char* socName = aclrtGetSocName(); + + at_npu::native::OpCommand cmd; + cmd.Name("load_and_reshape_flash_kernel"); + cmd.SetCustomHandler([scalar_type, slot_type, socName, stream, key_value_ptr, + key_cache_ptr, value_cache_ptr, slot_mapping_ptr, + hidden_dims, num_blocks, block_size, + num_tokens, num_layers, layer_idx]()->int { + auto slot_num = vllm_ascend::get_dtype_from_torch(slot_type); + auto dtype_num = vllm_ascend::get_dtype_from_torch(scalar_type); + auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(socName); + uint32_t aiv_num = ascendcPlatform->GetCoreNumAiv(); + kvcache_ops::load_and_reshape_flash_kernel(dtype_num, slot_num, aiv_num, stream, key_value_ptr, + key_cache_ptr, value_cache_ptr, slot_mapping_ptr, + hidden_dims, num_blocks, block_size, + num_tokens, num_layers, layer_idx, true); + return 0; + }); + cmd.Run(); + return; +}; + +void reshape_and_cache_back_flash( + torch::Tensor& key_value, // [2, num_layer, num_tokens, num_heads*head_size] + // must be one gpu / pinned cpu + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens], + const int layer_idx) { + + uint8_t* key_value_ptr = get_kernel_ptr(key_value); + uint8_t* key_cache_ptr = get_kernel_ptr(key_cache); + uint8_t* value_cache_ptr = get_kernel_ptr(value_cache); + + uint8_t* slot_mapping_ptr = get_kernel_ptr(slot_mapping); + + int num_tokens = slot_mapping.size(0); + int num_layers = key_value.size(1); + int block_size = key_cache.size(1); + int num_blocks = key_cache.size(0); + int hidden_dims = key_value.size(-1); + const c10::OptionalDeviceGuard device_guard(device_of(key_cache)); + const aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + + at::ScalarType scalar_type = key_value.scalar_type(); + at::ScalarType slot_type = slot_mapping.scalar_type(); + + const char* socName = aclrtGetSocName(); + + at_npu::native::OpCommand cmd; + cmd.Name("reshape_and_cache_back_flash"); + cmd.SetCustomHandler([scalar_type, slot_type, socName, stream, key_value_ptr, + key_cache_ptr, value_cache_ptr, slot_mapping_ptr, + 
hidden_dims, num_blocks, block_size, + num_tokens, num_layers, layer_idx]() -> int { + auto slot_num = vllm_ascend::get_dtype_from_torch(slot_type); + auto dtype_num = vllm_ascend::get_dtype_from_torch(scalar_type); + auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(socName); + uint32_t aiv_num = ascendcPlatform->GetCoreNumAiv(); + kvcache_ops::load_and_reshape_flash_kernel(dtype_num, slot_num, aiv_num, stream, key_value_ptr, + key_cache_ptr, value_cache_ptr, slot_mapping_ptr, + hidden_dims, num_blocks, block_size, + num_tokens, num_layers, layer_idx, false); + return 0; + }); + cmd.Run(); + return; +}; diff --git a/csrc/mem_kernels.h b/csrc/mem_kernels.h new file mode 100644 index 0000000..d291fc8 --- /dev/null +++ b/csrc/mem_kernels.h @@ -0,0 +1,58 @@ +#pragma once +#include +#include +#include "managed_mem.h" +#include "kernels/types.h" + +namespace kvcache_ops { +void multi_layer_kv_transfer_kernel(kvcache_ops::AscendType type, kvcache_ops::AscendType slotType, uint32_t blockDim, + void *stream, uint8_t *pagedKVCaches, uint8_t *dstCacheTensor, + uint8_t *slotmappings, const int64_t hiddenDims, const int32_t kvs, + const int32_t numLayers, const int64_t pageBuffSize, const int32_t numTokensChunk, + const bool page2L); + +void single_layer_kv_transfer_kernel(kvcache_ops::AscendType type, kvcache_ops::AscendType slotType, + uint32_t blockDim, void *stream, uint8_t *dstCacheTensor, + uint8_t *keyCachePtr, uint8_t *valueCachePtr, + uint8_t *slotmappings, const int64_t hiddenDims, const int32_t numTokens, + const bool page2L, const bool tokenMajor, const bool isMLA); + +void load_and_reshape_flash_kernel(kvcache_ops::AscendType type, kvcache_ops::AscendType slotType, + uint32_t blockDim, void *stream, uint8_t *dstCacheTensor, uint8_t *keyCachePtr, + uint8_t *valueCachePtr, uint8_t *slotmappings, const int64_t hiddenDims, + const int64_t numPages, const int32_t pagedSize, const int32_t numTokens, + const int32_t numLayers, const int32_t layerIdx, const bool page2L); +} + + +void multi_layer_kv_transfer(torch::Tensor& key_value, // [kv, num_layer, num_tokens, hidden] + const torch::Tensor& key_value_ptrs, // [num_layers] + const torch::Tensor& slot_mapping, // [num_tokens] + const torch::Device& paged_memory_device, + const int page_buffer_size, const bool direction, + const bool use_mla); + +void multi_layer_kv_transfer_unilateral(torch::Tensor& key_value, + const torch::Tensor& key_ptrs, + const torch::Tensor& value_ptrs, + const torch::Tensor& slot_mapping, + const torch::Device& paged_memory_device, + const int page_buffer_size, + const bool direction); + +void single_layer_kv_transfer(torch::Tensor& lmc_key_value_cache, + torch::Tensor& vllm_key_cache, + torch::Tensor& vllm_value_cache, + torch::Tensor& slot_mapping, + const bool direction, + const bool token_major = false); + +void load_and_reshape_flash(torch::Tensor& key_value, torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, const int layer_idx); + +void reshape_and_cache_back_flash(torch::Tensor& key_value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const int layer_idx); \ No newline at end of file diff --git a/csrc/pos_kernels.cpp b/csrc/pos_kernels.cpp new file mode 100644 index 0000000..01ee897 --- /dev/null +++ b/csrc/pos_kernels.cpp @@ -0,0 +1,14 @@ +#include "pos_kernels.h" +#include +#include + +namespace py = pybind11; + +void rotary_embedding_k_fused(const torch::Tensor& old_positions, + const torch::Tensor& 
new_positions, + torch::Tensor& key, int64_t head_size, + const torch::Tensor& cos_sin_cache, bool is_neox) { + // TODO: + PyErr_SetString(PyExc_NotImplementedError, "Please contact LMCache Ascend."); + throw py::error_already_set(); +}; \ No newline at end of file diff --git a/csrc/pos_kernels.h b/csrc/pos_kernels.h new file mode 100644 index 0000000..b0bcaf1 --- /dev/null +++ b/csrc/pos_kernels.h @@ -0,0 +1,10 @@ +#pragma once +#include +#include +#include +#include + +void rotary_embedding_k_fused(const torch::Tensor& old_positions, + const torch::Tensor& new_positions, + torch::Tensor& key, int64_t head_size, + const torch::Tensor& cos_sin_cache, bool is_neox); \ No newline at end of file diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp new file mode 100644 index 0000000..be39fb8 --- /dev/null +++ b/csrc/pybind.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "mem_kernels.h" +#include "managed_mem.h" +#include "cachegen_kernels.h" +#include "pos_kernels.h" +#include +#include + +namespace py = pybind11; + +PYBIND11_MODULE(c_ops, m) { + m.def("host_register", ®ister_memory); + m.def("multi_layer_kv_transfer", &multi_layer_kv_transfer); + m.def("single_layer_kv_transfer", &single_layer_kv_transfer); + m.def("multi_layer_kv_transfer_unilateral", + &multi_layer_kv_transfer_unilateral); + m.def("load_and_reshape_flash", &load_and_reshape_flash); + m.def("reshape_and_cache_back_flash", &reshape_and_cache_back_flash); + m.def("encode_fast_new", &encode_cuda_new); + m.def("decode_fast_new", &decode_cuda_new); + m.def("decode_fast_prefsum", &decode_cuda_prefsum); + m.def("calculate_cdf", &calculate_cdf); + m.def("rotary_embedding_k_fused", &rotary_embedding_k_fused); +} \ No newline at end of file diff --git a/csrc/utils.cmake b/csrc/utils.cmake new file mode 100644 index 0000000..4b59496 --- /dev/null +++ b/csrc/utils.cmake @@ -0,0 +1,27 @@ +# +# Run `EXPR` in python. The standard output of python is stored in `OUT` and +# has trailing whitespace stripped. If an error is encountered when running +# python, a fatal message `ERR_MSG` is issued. +# +function (run_python OUT EXPR ERR_MSG) + execute_process( + COMMAND + "${PYTHON_EXECUTABLE}" "-c" "${EXPR}" + OUTPUT_VARIABLE PYTHON_OUT + RESULT_VARIABLE PYTHON_ERROR_CODE + ERROR_VARIABLE PYTHON_STDERR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT PYTHON_ERROR_CODE EQUAL 0) + message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") + endif() + set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) +endfunction() + +# Run `EXPR` in python after importing `PKG`. Use the result of this to extend +# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. +macro (append_cmake_prefix_path PKG EXPR) + run_python(_PREFIX_PATH + "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") + list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) +endmacro() diff --git a/csrc/utils.h b/csrc/utils.h new file mode 100644 index 0000000..b6938d4 --- /dev/null +++ b/csrc/utils.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "kernels/types.h" +#include +#include + +namespace vllm_ascend { +kvcache_ops::AscendType get_dtype_from_torch(at::ScalarType scalarType) +{ + if (scalarType == at::ScalarType::Float) { + return kvcache_ops::AscendType::FP32; + } else if (scalarType == at::ScalarType::BFloat16) { + return kvcache_ops::AscendType::BF16; + } else if (scalarType == at::ScalarType::Half) { + return kvcache_ops::AscendType::FP16; + } else if (scalarType == at::ScalarType::Long) { + return kvcache_ops::AscendType::INT64; + } else if (scalarType == at::ScalarType::Int) { + return kvcache_ops::AscendType::INT32; + } else { + TORCH_CHECK(false, "ScalarType not supported."); + } +}; +} // namespace vllm_ascend \ No newline at end of file diff --git a/docker/Dockerfile.a2.openEuler b/docker/Dockerfile.a2.openEuler new file mode 100644 index 0000000..a8d72b4 --- /dev/null +++ b/docker/Dockerfile.a2.openEuler @@ -0,0 +1,76 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM quay.io/ascend/cann:8.2.rc1-910b-openeuler22.03-py3.11 + +ARG PIP_INDEX_URL="https://mirrors.aliyun.com/pypi/simple" +ARG COMPILE_CUSTOM_KERNELS=1 + +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} + +RUN yum update -y && \ + yum install -y python3-pip git vim wget net-tools gcc gcc-c++ make cmake numactl-devel && \ + rm -rf /var/cache/yum + +RUN pip config set global.index-url ${PIP_INDEX_URL} + +WORKDIR /workspace + +COPY . /workspace/LMCache-Ascend/ + +# Install vLLM +ARG VLLM_REPO=https://github.com/vllm-project/vllm.git +ARG VLLM_TAG=v0.9.2 +RUN git clone --depth 1 $VLLM_REPO --branch $VLLM_TAG /workspace/vllm +# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. 
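+# VLLM_TARGET_DEVICE="empty" installs vLLM without building any device-specific
+# extensions; the NPU backend is provided by vllm-ascend installed below.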
+RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip uninstall -y triton + +# Install vLLM-Ascend +ARG VLLM_ASCEND_REPO=https://github.com/vllm-project/vllm-ascend.git +ARG VLLM_ASCEND_TAG=v0.9.2rc1 +RUN git clone --depth 1 $VLLM_ASCEND_REPO --branch $VLLM_ASCEND_TAG /workspace/vllm-ascend +RUN cd /workspace/vllm-ascend && \ + git apply -p1 /workspace/LMCache-Ascend/docker/kv-connector-v1.diff +RUN export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + python3 -m pip install -v -e /workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/ + +# Install modelscope (for fast download) and ray (for multinode) +RUN python3 -m pip install modelscope ray + +# Install LMCache +ARG LMCACHE_REPO=https://github.com/LMCache/LMCache.git +ARG LMCACHE_TAG=v0.3.3 +RUN git clone --depth 1 $LMCACHE_REPO --branch $LMCACHE_TAG /workspace/LMCache +# our build is based on arm64 +RUN sed -i "s/^infinistore$/infinistore; platform_machine == 'x86_64'/" /workspace/LMCache/requirements/common.txt +RUN export NO_CUDA_EXT=1 && python3 -m pip install -v -e /workspace/LMCache + +# Install LMCache-Ascend +RUN cd /workspace/LMCache-Ascend && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export SOC_VERSION=ASCEND910B3 && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + python3 -m pip install -v --no-build-isolation -e . 
&& \ + python3 -m pip cache purge + +CMD ["/bin/bash"] + + diff --git a/docker/kv-connector-v1.diff b/docker/kv-connector-v1.diff new file mode 100644 index 0000000..95c5916 --- /dev/null +++ b/docker/kv-connector-v1.diff @@ -0,0 +1,219 @@ +diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py +index 7d7f488..5b5d41c 100644 +--- a/vllm_ascend/attention/attention_v1.py ++++ b/vllm_ascend/attention/attention_v1.py +@@ -28,6 +28,10 @@ from vllm.forward_context import ForwardContext, get_forward_context + from vllm.utils import direct_register_custom_op + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch ++from vllm.distributed.kv_transfer import (get_kv_transfer_group, ++ has_kv_transfer_group, ++ is_v1_kv_transfer_group) ++ + + from vllm_ascend.ops.attention import vanilla_chunked_prefill + from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, +@@ -436,6 +440,34 @@ class AscendAttentionBackendImpl(AttentionImpl): + ori_output[:, :, :] = output[:num_tokens, :, :] + return output.view(num_tokens, self.hidden_size) + ++def wait_for_kv_layer_from_connector(layer_name: str): ++ if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): ++ return ++ ++ connector = get_kv_transfer_group() ++ ++ forward_context: ForwardContext = get_forward_context() ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ #assert isinstance(attn_metadata, dict) ++ connector.wait_for_layer_load(layer_name) ++ ++def maybe_save_kv_layer_to_connector( ++ layer_name: str, ++ kv_cache_layer: List[torch.Tensor], ++): ++ if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): ++ return ++ ++ connector = get_kv_transfer_group() ++ ++ forward_context: ForwardContext = get_forward_context() ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ connector.save_kv_layer(layer_name, kv_cache_layer, ++ attn_metadata) + + def unified_ascend_attention_with_output( + query: torch.Tensor, +@@ -444,6 +476,7 @@ def unified_ascend_attention_with_output( + output: torch.Tensor, + layer_name: str, + ) -> None: ++ wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self = forward_context.no_compile_layers[layer_name] +@@ -456,6 +489,7 @@ def unified_ascend_attention_with_output( + attn_metadata, + output, + trace_flag=False) ++ maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return + + +diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py +index eabcdbc..73a0ef6 100644 +--- a/vllm_ascend/worker/model_runner_v1.py ++++ b/vllm_ascend/worker/model_runner_v1.py +@@ -17,6 +17,7 @@ + # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py + # + ++import copy + import gc + import os + import time +@@ -57,6 +58,13 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + from vllm.v1.core.encoder_cache_manager import compute_encoder_budget + from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec) ++from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 ++from vllm.distributed.kv_transfer import (get_kv_transfer_group, ++ has_kv_transfer_group) ++from vllm.distributed.parallel_state import ( ++ get_pp_group, get_tp_group, graph_capture) ++from vllm.forward_context import (get_forward_context, ++ set_forward_context) + from 
vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, + ModelRunnerOutput) + from vllm.v1.sample.metadata import SamplingMetadata +@@ -1103,6 +1111,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): + with set_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens): ++ self.maybe_setup_kv_connector(scheduler_output) ++ + with ProfileExecuteDuration().capture_async("forward"): + model_kwargs = {} + if self.torchair_graph_enabled: +@@ -1134,6 +1144,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): + **model_kwargs, + ) + ++ self.maybe_wait_for_kv_save() ++ finished_sending, finished_recving = ( ++ self.get_finished_kv_transfers(scheduler_output)) ++ + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: +@@ -1163,7 +1177,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): + + return (attn_metadata, hidden_states, spec_decode_metadata, positions, + total_num_scheduled_tokens, logits_indices, aux_hidden_states, +- num_scheduled_tokens) ++ num_scheduled_tokens, finished_sending, finished_recving) + + def _get_cumsum_and_arange( + self, +@@ -1350,6 +1364,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + num_scheduled_tokens_np: np.ndarray, ++ finished_sending: Optional[set[str]], ++ finished_recving: Optional[set[str]], + ) -> ModelRunnerOutput: + assert self.input_batch.num_reqs ==\ + len(self.input_batch.pooling_params), \ +@@ -1384,6 +1400,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, ++ finished_sending=finished_sending, ++ finished_recving=finished_recving + ) + + @torch.inference_mode() +@@ -1400,7 +1418,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): + return EMPTY_MODEL_RUNNER_OUTPUT + (attn_metadata, hidden_states, spec_decode_metadata, positions, + num_scheduled_tokens, logits_indices, aux_hidden_states, +- num_scheduled_tokens_np) = (self._process_reqs( ++ num_scheduled_tokens_np, finished_sending, finished_recving) = (self._process_reqs( + scheduler_output, intermediate_tensors)) + + with ProfileExecuteDuration().capture_async("post process"): +@@ -1422,7 +1440,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): + else: + if self.input_batch.pooling_params: + return self._pool(hidden_states, num_scheduled_tokens, +- num_scheduled_tokens_np) ++ num_scheduled_tokens_np, finished_sending, finished_recving) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: +@@ -1561,6 +1579,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], ++ finished_sending=finished_sending, ++ finished_recving=finished_recving + ) + + durations = ProfileExecuteDuration().pop_captured_sync() +@@ -1575,6 +1595,52 @@ class NPUModelRunner(LoRAModelRunnerMixin): + + return model_runner_output + ++ def kv_connector_no_forward( ++ self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: ++ # KV send/recv even if no work to do. 
++ with set_forward_context(None, self.vllm_config): ++ self.maybe_setup_kv_connector(scheduler_output) ++ finished_sending, finished_recving = ( ++ self.get_finished_kv_transfers(scheduler_output)) ++ ++ if not finished_sending and not finished_recving: ++ return EMPTY_MODEL_RUNNER_OUTPUT ++ ++ output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) ++ output.finished_sending = finished_sending ++ output.finished_recving = finished_recving ++ return output ++ ++ @staticmethod ++ def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): ++ # Update KVConnector with the KVConnector metadata forward(). ++ if has_kv_transfer_group(): ++ kv_connector = get_kv_transfer_group() ++ assert isinstance(kv_connector, KVConnectorBase_V1) ++ assert scheduler_output.kv_connector_metadata is not None ++ kv_connector.bind_connector_metadata( ++ scheduler_output.kv_connector_metadata) ++ ++ # Background KV cache transfers happen here. ++ # These transfers are designed to be async and the requests ++ # involved may be disjoint from the running requests. ++ # Do this here to save a collective_rpc. ++ kv_connector.start_load_kv(get_forward_context()) ++ ++ @staticmethod ++ def maybe_wait_for_kv_save() -> None: ++ if has_kv_transfer_group(): ++ get_kv_transfer_group().wait_for_save() ++ ++ @staticmethod ++ def get_finished_kv_transfers( ++ scheduler_output: "SchedulerOutput", ++ ) -> tuple[Optional[set[str]], Optional[set[str]]]: ++ if has_kv_transfer_group(): ++ return get_kv_transfer_group().get_finished( ++ scheduler_output.finished_req_ids) ++ return None, None ++ + @torch.inference_mode() + def _dummy_run( + self, diff --git a/lmcache_ascend/__init__.py b/lmcache_ascend/__init__.py new file mode 100644 index 0000000..174b823 --- /dev/null +++ b/lmcache_ascend/__init__.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +import sys +import lmcache +import lmcache_ascend +import lmcache_ascend.c_ops as ascend_c_ops + +sys.modules["lmcache.c_ops"] = ascend_c_ops + +from lmcache_ascend.v1.cache_engine import _ascend_create_memory_allocator +import lmcache.v1.cache_engine + +lmcache.v1.cache_engine.LMCacheEngineBuilder._Create_memory_allocator = ( + _ascend_create_memory_allocator +) + +from lmcache_ascend.integration.vllm.vllm_v1_adapter import ( + init_lmcache_engine as ascend_init_lmcache_engine, +) +import lmcache.integration.vllm.vllm_adapter + +lmcache.integration.vllm.vllm_adapter.init_lmcache_engine = ascend_init_lmcache_engine diff --git a/lmcache_ascend/integration/vllm/__init__.py b/lmcache_ascend/integration/vllm/__init__.py new file mode 100644 index 0000000..9881313 --- /dev/null +++ b/lmcache_ascend/integration/vllm/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/lmcache_ascend/integration/vllm/lmcache_ascend_connector_v1.py b/lmcache_ascend/integration/vllm/lmcache_ascend_connector_v1.py new file mode 100644 index 0000000..0bd682a --- /dev/null +++ b/lmcache_ascend/integration/vllm/lmcache_ascend_connector_v1.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Third Party +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.logger import init_logger + +# First Party +import lmcache_ascend +from lmcache.integration.vllm.lmcache_connector_v1 import LMCacheConnectorV1Dynamic + +logger = init_logger(__name__) + + +class LMCacheAscendConnectorV1Dynamic(LMCacheConnectorV1Dynamic): + def __init__(self, vllm_config: "VllmConfig", 
role: KVConnectorRole) -> None: + super().__init__(vllm_config=vllm_config, role=role) diff --git a/lmcache_ascend/integration/vllm/vllm_v1_adapter.py b/lmcache_ascend/integration/vllm/vllm_v1_adapter.py new file mode 100644 index 0000000..151712f --- /dev/null +++ b/lmcache_ascend/integration/vllm/vllm_v1_adapter.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +# Third Party +# TODO: Currently we patch all the cuda calls due to effort to port all torch.cuda +# will disabled torch.jit +from torch_npu.contrib import transfer_to_npu + +# First Party +from lmcache.config import LMCacheEngineMetadata +from lmcache.integration.vllm.utils import ENGINE_NAME, lmcache_get_config +from lmcache.integration.vllm.vllm_adapter import ( + VLLM_CACHE_CONFIG, + VLLM_MODEL_CONFIG, + VLLM_PARALLEL_CONFIG, + VLLM_SCHEDULER_CONFIG, + need_gpu_interm_buffer, +) +from lmcache.v1.cache_engine import LMCacheEngine, LMCacheEngineBuilder +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.gpu_connector import ( + VLLMBufferLayerwiseGPUConnector, +) +from lmcache_ascend.v1.npu_connector import ( + VLLMPagedMemNPUConnectorV2, + VLLMPagedMemLayerwiseNPUConnector, +) +from lmcache.logging import init_logger + +# Third Party +import torch +from vllm.config import ( + CacheConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.utils import get_kv_cache_torch_dtype + +logger = init_logger(__name__) + + +# We need to patch this function due to connector modification +def init_lmcache_engine( + model_config: ModelConfig, + parallel_config: ParallelConfig, + cache_config: CacheConfig, + scheduler_config: SchedulerConfig, +) -> Optional[LMCacheEngine]: + """Initialize the LMCache engine by the given model config and parallel + config. This function will check the environment variable + `LMCACHE_CONFIG_FILE` to load the configuration file. If that environment + variable is not set, this function will return None. + + :param model_config: The model configuration in vLLM. + :type model_config: ModelConfig + :param parallel_config: The parallel configuration in vLLM. + :type parallel_config: ParallelConfig + :param cache_config: The KV cache configuration in vLLM. + :type cache_config: CacheConfig + :param scheduler_config: The scheduler configuration in vLLM. + :type scheduler_config: SchedulerConfig + + :return: The initialized LMCache engine or None (if the environment variable + `LMCACHE_CONFIG_FILE` is not set). + :rtype: Optional[LMCacheEngine] + """ + if LMCacheEngineBuilder.get(ENGINE_NAME) is not None: + return None + + global VLLM_CACHE_CONFIG + global VLLM_PARALLEL_CONFIG + global VLLM_MODEL_CONFIG + global VLLM_SCHEDULER_CONFIG + VLLM_CACHE_CONFIG = cache_config + VLLM_PARALLEL_CONFIG = parallel_config + VLLM_MODEL_CONFIG = model_config + VLLM_SCHEDULER_CONFIG = scheduler_config + + config = lmcache_get_config() + + assert isinstance(config, LMCacheEngineConfig), ( + "LMCache v1 configuration is should be passed." 
+ ) + + kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype) + + use_mla = False + if ( + hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla + ): + use_mla = True + + if use_mla and (config.remote_serde != "naive" and config.remote_serde is not None): + raise ValueError("MLA only works with naive serde mode..") + + # construct kv shape (for mem pool) + num_layer = model_config.get_num_layers(parallel_config) + chunk_size = config.chunk_size + num_kv_head = model_config.get_num_kv_heads(parallel_config) + head_size = model_config.get_head_size() + kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size) + logger.info(f"use mla: {use_mla}, kv shape: {kv_shape}") + + # Change current device. + torch.cuda.device(parallel_config.rank) + device = torch.device(f"cuda:{parallel_config.rank}") + metadata = LMCacheEngineMetadata( + model_config.model, + parallel_config.world_size, + parallel_config.rank, + "vllm", + kv_dtype, + kv_shape, + use_mla, + ) + + use_gpu = need_gpu_interm_buffer(config) + vllm_gpu_connector: Union[ + VLLMPagedMemNPUConnectorV2, + VLLMPagedMemLayerwiseNPUConnector, + ] + + if use_mla and config.use_layerwise: + raise ValueError("layerwise MLA connector is not supported yet") + + # When use_mla is True, num_kv_head is 1 + hidden_dim_size = num_kv_head * head_size + if config.use_layerwise: + if config.enable_blending: + # Use layerwise connector for blending + raise NotImplementedError("Blending is not yet supported for Ascend.") + else: + vllm_gpu_connector = VLLMPagedMemLayerwiseNPUConnector( + hidden_dim_size, + num_layer, + use_gpu=use_gpu, + chunk_size=chunk_size, + dtype=kv_dtype, + device=device, + ) + else: + vllm_gpu_connector = VLLMPagedMemNPUConnectorV2( + hidden_dim_size, + num_layer, + use_gpu=use_gpu, + chunk_size=chunk_size, + dtype=kv_dtype, + device=device, + use_mla=use_mla, + ) + engine = LMCacheEngineBuilder.get_or_create( + ENGINE_NAME, config, metadata, vllm_gpu_connector + ) + + return engine diff --git a/lmcache_ascend/v1/__init__.py b/lmcache_ascend/v1/__init__.py new file mode 100644 index 0000000..9881313 --- /dev/null +++ b/lmcache_ascend/v1/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/lmcache_ascend/v1/cache_engine.py b/lmcache_ascend/v1/cache_engine.py new file mode 100644 index 0000000..54e3a1e --- /dev/null +++ b/lmcache_ascend/v1/cache_engine.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 + +# First Party +from lmcache.config import LMCacheEngineMetadata +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import MemoryAllocatorInterface +from .memory_management import AscendMixedMemoryAllocator + + +def _ascend_create_memory_allocator( + config: LMCacheEngineConfig, + metadata: LMCacheEngineMetadata, +) -> MemoryAllocatorInterface: + if config.enable_nixl: + raise NotImplementedError("Ascend does not support nixl.") + + if config.weka_path is not None or config.gds_path is not None: + raise NotImplementedError("Ascend does not support Direct Storage.") + + max_local_cpu_size = config.max_local_cpu_size + return AscendMixedMemoryAllocator(int(max_local_cpu_size * 1024**3)) diff --git a/lmcache_ascend/v1/memory_management.py b/lmcache_ascend/v1/memory_management.py new file mode 100644 index 0000000..c673609 --- /dev/null +++ b/lmcache_ascend/v1/memory_management.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from contextlib import 
nullcontext +import threading + +# Third Party +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.memory_management import ( + MixedMemoryAllocator, + PagedTensorMemoryAllocator, + TensorMemoryAllocator, + BufferAllocator, + PinMemoryAllocator, +) +import lmcache_ascend.c_ops as lmc_ops + +logger = init_logger(__name__) + + +# NOTE (Gingfung): it is not really used in v1, mainly for testing. +class AscendPinMemoryAllocator(PinMemoryAllocator): + """Allocates memory in the pre-allocated pinned memory.""" + + def __init__(self, size: int, use_paging: bool = False, **kwargs): + """ + :param int size: The size of the pinned memory in bytes. + """ + + self.buffer = torch.empty( + size, dtype=torch.uint8, device="cpu", pin_memory=True + ) + lmc_ops.host_register(self.buffer) + + if use_paging: + assert "shape" in kwargs, ( + "shape must be specified for paged memory allocator" + ) + assert "dtype" in kwargs, ( + "dtype must be specified for paged memory allocator" + ) + assert "fmt" in kwargs, "fmt must be specified for paged memory allocator" + self.allocator = PagedTensorMemoryAllocator( + tensor=self.buffer, + shape=kwargs["shape"], + dtype=kwargs["dtype"], + fmt=kwargs["fmt"], + ) + else: + self.allocator = TensorMemoryAllocator(self.buffer) + + self.host_mem_lock = threading.Lock() if not use_paging else nullcontext() + + def close(self): + pass + + +class AscendMixedMemoryAllocator(MixedMemoryAllocator): + def __init__(self, size: int, use_paging: bool = False, **kwargs) -> None: + """ + :param int size: The size of the pinned memory in bytes. + """ + + self.buffer = torch.empty( + size, dtype=torch.uint8, device="cpu", pin_memory=True + ) + lmc_ops.host_register(self.buffer) + + if use_paging: + assert "shape" in kwargs, ( + "shape must be specified for paged memory allocator" + ) + assert "dtype" in kwargs, ( + "dtype must be specified for paged memory allocator" + ) + assert "fmt" in kwargs, "fmt must be specified for paged memory allocator" + self.pin_allocator = PagedTensorMemoryAllocator( + tensor=self.buffer, + shape=kwargs["shape"], + dtype=kwargs["dtype"], + fmt=kwargs["fmt"], + ) + else: + self.pin_allocator = TensorMemoryAllocator(self.buffer) + + self.host_mem_lock = threading.Lock() if not use_paging else nullcontext() + + self.buffer_allocator = BufferAllocator("cpu") + + def close(self): + pass diff --git a/lmcache_ascend/v1/npu_connector.py b/lmcache_ascend/v1/npu_connector.py new file mode 100644 index 0000000..cc4d859 --- /dev/null +++ b/lmcache_ascend/v1/npu_connector.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import List, Union + +# Third Party +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.gpu_connector import ( + VLLMPagedMemGPUConnectorV2, + VLLMPagedMemLayerwiseGPUConnector, +) +from lmcache.v1.memory_management import MemoryFormat, MemoryObj +import lmcache_ascend.c_ops as lmc_ops + +logger = init_logger(__name__) + + +class VLLMPagedMemNPUConnectorV2(VLLMPagedMemGPUConnectorV2): + def _initialize_pointers(self, kv_caches: List[torch.Tensor]) -> torch.Tensor: + self.kv_cache_pointers.numpy()[:] = [t.data_ptr() for t in kv_caches] + device = kv_caches[0].device + assert device.type == "npu", "The device should be Ascend NPU." 
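+        # Allocate the device-side pointer table once per NPU index, then refresh it with the current host-side layer pointers before handing it to the transfer kernel.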
+ idx = device.index + if idx not in self.kv_cache_pointers_on_gpu: + self.kv_cache_pointers_on_gpu[idx] = torch.empty( + self.num_layers, dtype=torch.int64, device=device + ) + self.kv_cache_pointers_on_gpu[idx].copy_(self.kv_cache_pointers) + if self.use_mla: + # kv_caches[0].shape: [num_pages, page_size, head_size] + # kv_caches[0].shape: [1, num_pages, page_size, head_size] (vllm-Ascend) + self.page_buffer_size = kv_caches[0].shape[-3] * kv_caches[0].shape[-2] + else: + # kv_caches[0].shape: [2, num_pages, page_size, num_heads, head_size] + assert kv_caches[0].dim() == 5 + self.page_buffer_size = kv_caches[0].shape[1] * kv_caches[0].shape[2] + + return self.kv_cache_pointers_on_gpu[idx] + + +class VLLMPagedMemLayerwiseNPUConnector(VLLMPagedMemLayerwiseGPUConnector): + def batched_to_gpu(self, starts: List[int], ends: List[int], **kwargs): + """ + This function is a generator that moves the KV cache from the memory + objects to paged GPU memory. The first iteration will prepare some + related metadata. In each of the following iterations, it will first + wait until the loading of the previous layer finishes, and then load + one layer of KV cache from the memory objects -> GPU buffer -> + paged GPU memory. The last iteration simply waits for the last layer + to finish. + In total, this generator will yield num_layers + 2 times. + + :param starts: The starting indices of the KV cache in the corresponding + token sequence. + + :param ends: The ending indices of the KV cache in the corresponding + token sequence. + + :raises ValueError: If 'slot_mapping' is not provided in kwargs. + """ + + self.initialize_kvcaches_ptr(**kwargs) + assert self.kvcaches is not None, ( + "kvcaches should be provided in kwargs or initialized beforehand." + ) + + if "slot_mapping" not in kwargs: + raise ValueError("'slot_mapping' should be provided in kwargs.") + + if "sync" not in kwargs: + raise ValueError("'sync' should be provided in kwargs.") + + slot_mapping: torch.Tensor = kwargs["slot_mapping"] + sync: bool = kwargs["sync"] + + self._lazy_initialize_buffer(self.kvcaches) + + slot_mapping_chunks = [] + for start, end in zip(starts, ends, strict=False): + slot_mapping_chunks.append(slot_mapping[start:end]) + + # TODO(Jiayi): Optimize away this `cat` + slot_mapping_full = torch.cat(slot_mapping_chunks, dim=0) + + num_tokens = len(slot_mapping_full) + + if self.use_gpu: + buffer_shape = self.get_shape(num_tokens) + tmp_gpu_buffer_obj = self.gpu_buffer_allocator.allocate( + buffer_shape, self.dtype, MemoryFormat.KV_T2D + ) + assert tmp_gpu_buffer_obj is not None, ( + "Failed to allocate GPU buffer in GPUConnector" + ) + assert tmp_gpu_buffer_obj.tensor is not None + + offset = starts[0] + current_stream = torch.cuda.current_stream() + + for layer_id in range(self.num_layers): + memory_objs_layer = yield + if sync: + current_stream.wait_stream(self.load_stream) + if layer_id > 0: + logger.debug(f"Finished loading layer {layer_id - 1}") + + # memobj -> gpu_buffer -> kvcaches + with torch.cuda.stream(self.load_stream): + for start, end, memory_obj in zip( + starts, ends, memory_objs_layer, strict=False + ): + assert memory_obj.metadata.fmt == MemoryFormat.KV_T2D + if self.use_gpu: + tmp_gpu_buffer_obj.tensor[start - offset : end - offset].copy_( + memory_obj.tensor, non_blocking=True + ) + else: + lmc_ops.single_layer_kv_transfer( + memory_obj.tensor, + self.kvcaches[layer_id][0], + self.kvcaches[layer_id][1], + slot_mapping[start:end], + False, + True, + ) + + if self.use_gpu:
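+                    # All chunks for this layer are now staged in the temporary device buffer, so a single kernel launch scatters the whole layer from the staging buffer into the paged NPU KV cache at the slots given by slot_mapping_full.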
+ lmc_ops.single_layer_kv_transfer( + tmp_gpu_buffer_obj.tensor, + self.kvcaches[layer_id][0], + self.kvcaches[layer_id][1], + slot_mapping_full, + False, + True, + ) + yield + + # synchronize the last layer + if sync: + current_stream.wait_stream(self.load_stream) + + # free the buffer memory + if self.use_gpu: + tmp_gpu_buffer_obj.ref_count_down() + + logger.debug(f"Finished loading layer {layer_id}") + yield + + def batched_from_gpu( + self, + memory_objs: Union[List[List[MemoryObj]], List[MemoryObj]], + starts: List[int], + ends: List[int], + **kwargs, + ): + """ + This function is a generator that moves the KV cache from the paged GPU + memory to the memory objects. The first iteration will prepare some + related metadata and initiate the transfer of the first layer. In each + of the following iterations, it will first wait until the storing of the + previous layer finishes, and then initiate storing the KV cache of the + current layer. The storing process of the KV cache is paged GPU + memory -> GPU buffer -> memory objects. The last iteration simply waits + for the last layer to finish. + In total, this generator will yield num_layers + 1 times. + + :param memory_objs: The memory objects to store the KV cache. The first + dimension is the number of layers, and the second dimension is the + number of memory objects (i.e., number of chunks) for each layer. + + :param starts: The starting indices of the KV cache in the corresponding + token sequence. + + :param ends: The ending indices of the KV cache in the corresponding + token sequence. + + :raises ValueError: If 'slot_mapping' is not provided in kwargs. + """ + self.initialize_kvcaches_ptr(**kwargs) + assert self.kvcaches is not None, ( + "kvcaches should be provided in kwargs or initialized beforehand."
+ ) + + if "slot_mapping" not in kwargs: + raise ValueError("'slot_mapping' should be provided in kwargs.") + + if "sync" not in kwargs: + raise ValueError("'sync' should be provided in kwargs.") + + slot_mapping: torch.Tensor = kwargs["slot_mapping"] + sync: bool = kwargs["sync"] + + self._lazy_initialize_buffer(self.kvcaches) + + slot_mapping_chunks = [] + for start, end in zip(starts, ends, strict=False): + slot_mapping_chunks.append(slot_mapping[start:end]) + + slot_mapping_full = torch.cat(slot_mapping_chunks, dim=0) + + num_tokens = len(slot_mapping_full) + + if self.use_gpu: + buffer_shape = self.get_shape(num_tokens) + tmp_gpu_buffer_obj = self.gpu_buffer_allocator.allocate( + buffer_shape, self.dtype, MemoryFormat.KV_T2D + ) + assert tmp_gpu_buffer_obj is not None, ( + "Failed to allocate GPU buffer in GPUConnector" + ) + assert tmp_gpu_buffer_obj.tensor is not None + + offset = starts[0] + current_stream = torch.cuda.current_stream() + + for layer_id in range(self.num_layers): + memory_objs_layer = memory_objs[layer_id] + # kvcaches -> gpu_buffer -> memobj + with torch.cuda.stream(self.store_stream): + self.store_stream.wait_stream(current_stream) + if self.use_gpu: + lmc_ops.single_layer_kv_transfer( + tmp_gpu_buffer_obj.tensor, + self.kvcaches[layer_id][0], + self.kvcaches[layer_id][1], + slot_mapping_full, + True, + True, + ) + for start, end, memory_obj in zip( + starts, ends, memory_objs_layer, strict=False + ): + assert memory_obj.tensor is not None + if self.use_gpu: + memory_obj.tensor.copy_( + tmp_gpu_buffer_obj.tensor[start - offset : end - offset], + non_blocking=True, + ) + else: + lmc_ops.single_layer_kv_transfer( + memory_obj.tensor, + self.kvcaches[layer_id][0], + self.kvcaches[layer_id][1], + slot_mapping[start:end], + True, + True, + ) + + yield + if sync: + self.store_stream.synchronize() + logger.debug(f"Finished offloading layer {layer_id}") + + # free the buffer memory + if self.use_gpu: + tmp_gpu_buffer_obj.ref_count_down() + yield diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7e89e5b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,118 @@ +[build-system] +requires = [ + "setuptools>=61.0", + "setuptools_scm[toml]>=6.2", + "torch", + "torch-npu", + "wheel" +] +build-backend = "setuptools.build_meta" + +[project] +name = "lmcache-ascend" +authors = [ + { name = "LMCache Team", email = "lmcacheteam@gmail.com" }, + { name = "GingFung Matthew Yeung", email = "gingfung.matthew.yeung@huawei.com" }, +] +license = { text = "Apache-2.0" } +readme = "README.md" +description = "LMCache on Ascend NPU" +classifiers = [ + "Development Status :: 3 - Alpha", + "Operating System :: POSIX :: Linux", + "Environment :: GPU", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", +] +requires-python = ">=3.10,<3.13" +dynamic = ["dependencies", "version"] + +[project.urls] +homepage = "https://docs.lmcache.ai" +source = "https://github.com/LMCache/LMCache-Ascend" +issues = "https://github.com/LMCache/LMCache-Ascend" + +[tool.setuptools_scm] +version_file = "lmcache_ascend/_version.py" +# do not include +gREV local version, required for Test PyPI upload +local_scheme = "no-local-version" + +[tool.setuptools.packages.find] +where = [""] +include = ["lmcache_ascend", "lmcache_ascend*"] + +[tool.ruff] +# same 
as Black's default line length +line-length = 88 + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + # "UP", + # flake8-bugbear + "B", + # flake8-simplify + #"SIM", + # Ruff does not support isort's import_headings feature, yet. + # "I", + # flake8-logging-format + #"G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", +] + +[tool.ruff.lint.isort] +# same as .isort.cfg +from-first = true +# not supported yet +# import-heading-future=Future +# import-heading-stdlib=Standard +# import-heading-thirdparty=Third Party +# import-heading-firstparty=First Party +# import-heading-localfolder=Local + +[tool.mypy] +modules = ["lmcache_ascend", "tests"] + +# TODO: tighten MyPy checks by enabling these checks over time. +disable_error_code = [ + "annotation-unchecked", + "union-attr", + "var-annotated", + "arg-type", + "call-arg", + "import-untyped", + "attr-defined", + "return-value", + "assignment", + "call-overload", + "misc", +] + +ignore_missing_imports = true +explicit_package_bases = true + +# TODO: tighten MyPy checks by enabling these checks over time. +check_untyped_defs = false +disallow_incomplete_defs = false +disallow_untyped_defs = false +disallow_untyped_calls = false +warn_return_any = false + +follow_imports = "silent" diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..4d49b62 --- /dev/null +++ b/requirement.txt @@ -0,0 +1,2 @@ +torch +torch_npu diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..1547f1f --- /dev/null +++ b/setup.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from pathlib import Path +import os +import sys + +# Third Party +from setuptools import find_packages, setup, Extension +from setuptools.command.build_ext import build_ext +from setuptools.command.develop import develop +from setuptools.command.install import install + +import logging +import sysconfig +import subprocess +import platform +import shutil + + +ROOT_DIR = Path(__file__).parent + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def _get_ascend_home_path(): + # NOTE: standard Ascend CANN toolkit path + return os.environ.get("ASCEND_HOME_PATH", "/usr/local/Ascend/ascend-toolkit/latest") + + +def _get_ascend_env_path(): + # NOTE: standard Ascend Environment variable setup path + env_script_path = os.path.realpath( + os.path.join(_get_ascend_home_path(), "..", "set_env.sh") + ) + if not os.path.exists(env_script_path): + raise ValueError( + f"The file '{env_script_path}' is not found, " + "please make sure environment variable 'ASCEND_HOME_PATH' is set correctly." 
+ ) + return env_script_path + + +def _get_npu_soc(): + _soc_version = os.getenv("SOC_VERSION", None) + if _soc_version is None: + npu_smi_cmd = [ + "bash", + "-c", + "npu-smi info | grep OK | awk '{print $3}' | head -n 1", + ] + try: + _soc_version = subprocess.check_output(npu_smi_cmd, text=True).strip() + _soc_version = _soc_version.split("-")[0] + _soc_version = "Ascend" + _soc_version + return _soc_version + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Retrieve SoC version failed: {e}") + return _soc_version + + +class CMakeExtension(Extension): + def __init__(self, name: str, cmake_lists_dir: str = ".", **kwargs) -> None: + super().__init__(name, sources=[], py_limited_api=False, **kwargs) + self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) + + +class custom_install(install): + def run(self): + self.run_command("build_ext") + install.run(self) + + +class CustomAscendCmakeBuildExt(build_ext): + def build_extension(self, ext): + # build the so as c_ops + ext_name = ext.name.split(".")[-1] + so_name = ext_name + ".so" + logger.info(f"Building {so_name} ...") + BUILD_OPS_DIR = os.path.join(ROOT_DIR, "build") + os.makedirs(BUILD_OPS_DIR, exist_ok=True) + + ascend_home_path = _get_ascend_home_path() + env_path = _get_ascend_env_path() + _soc_version = _get_npu_soc() + _cxx_compiler = os.getenv("CXX") + _cc_compiler = os.getenv("CC") + python_executable = sys.executable + + try: + # if pybind11 is installed via pip + pybind11_cmake_path = ( + subprocess.check_output( + [python_executable, "-m", "pybind11", "--cmakedir"] + ) + .decode() + .strip() + ) + except subprocess.CalledProcessError as e: + # else specify pybind11 path installed from source code on CI container + raise RuntimeError(f"CMake configuration failed: {e}") + + import torch_npu + + torch_npu_path = os.path.dirname(os.path.abspath(torch_npu.__file__)) + import torch + + torch_path = os.path.dirname(os.path.abspath(torch.__file__)) + + # python include + python_include_path = sysconfig.get_path("include", scheme="posix_prefix") + + arch = platform.machine() + install_path = os.path.join(BUILD_OPS_DIR, "install") + if isinstance(self.distribution.get_command_obj("develop"), develop): + install_path = BUILD_OPS_DIR + + cmake_cmd = [ + f"source {env_path} && " + f"cmake -S {ROOT_DIR} -B {BUILD_OPS_DIR}" + f" -DSOC_VERSION={_soc_version}" + f" -DARCH={arch}" + " -DUSE_ASCEND=1" + f" -DPYTHON_EXECUTABLE={python_executable}" + f" -DCMAKE_PREFIX_PATH={pybind11_cmake_path}" + f" -DCMAKE_BUILD_TYPE=Release" + f" -DCMAKE_INSTALL_PREFIX={install_path}" + f" -DPYTHON_INCLUDE_PATH={python_include_path}" + f" -DTORCH_NPU_PATH={torch_npu_path}" + f" -DTORCH_PATH={torch_path}" + f" -DASCEND_CANN_PACKAGE_PATH={ascend_home_path}" + " -DCMAKE_VERBOSE_MAKEFILE=ON" + ] + + if _cxx_compiler is not None: + cmake_cmd += [f" -DCMAKE_CXX_COMPILER={_cxx_compiler}"] + + if _cc_compiler is not None: + cmake_cmd += [f" -DCMAKE_C_COMPILER={_cc_compiler}"] + + cmake_cmd += [f" && cmake --build {BUILD_OPS_DIR} -j --verbose"] + cmake_cmd += [f" && cmake --install {BUILD_OPS_DIR}"] + cmake_cmd = "".join(cmake_cmd) + + logger.info(f"Start running CMake commands:\n{cmake_cmd}") + try: + _ = subprocess.run( + cmake_cmd, cwd=ROOT_DIR, text=True, shell=True, check=True + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to build {so_name}: {e}") + + build_lib_dir = self.get_ext_fullpath(ext.name) + os.makedirs(os.path.dirname(build_lib_dir), exist_ok=True) + + package_name = ext.name.split(".")[0] # e.g., 
'lmcache' + src_dir = os.path.join(ROOT_DIR, package_name) + + for root, _, files in os.walk(install_path): + for file in files: + if file.endswith(".so"): + src_path = os.path.join(root, file) + dst_path = os.path.join(os.path.dirname(build_lib_dir), file) + if os.path.exists(dst_path): + os.remove(dst_path) + + if isinstance( + self.distribution.get_command_obj("develop"), develop + ): + # For the ascend kernels + src_dir_file = os.path.join(src_dir, file) + shutil.copy(src_path, src_dir_file) + shutil.copy(src_path, dst_path) + + logger.info(f"Copied {file} to {dst_path}") + + +def ascend_extension(): + print("Building Ascend extensions") + return [CMakeExtension(name="lmcache_ascend.c_ops")], { + "build_ext": CustomAscendCmakeBuildExt + } + + +if __name__ == "__main__": + ext_modules, cmdclass = ascend_extension() + + setup( + packages=find_packages( + exclude=("csrc",) + ), # Ensure csrc is excluded if it only contains sources + ext_modules=ext_modules, + cmdclass=cmdclass, + include_package_data=True, + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6c46c1a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# From LMCache +# Standard +from dataclasses import dataclass +from unittest.mock import patch +import random +import shlex +import socket +import subprocess +import time + +# Third Party +import pytest +import torch_npu + +# First Party +import lmcache_ascend +from lmcache.v1.cache_engine import LMCacheEngineBuilder + + +class MockRedis: + def __init__( + self, host=None, port=None, url=None, decode_responses=False, **kwargs + ): + self.store = {} + self.host = host + self.port = port + self.url = url + self.decode_responses = decode_responses + + def set(self, key, value): + self.store[key] = value + return True + + def get(self, key): + return self.store.get(key, None) + + def exists(self, key): + return key in self.store + + def scan(self, cursor=0, match=None): + keys = [s.encode("utf-8") for s in self.store.keys()] + return (0, keys) + + def close(self): + pass + + @classmethod + def from_url(cls, url, decode_responses=False, **kwargs): + """Mock implementation of Redis.from_url""" + return cls(url=url, decode_responses=decode_responses, **kwargs) + + +class MockRedisSentinel: + def __init__(self, hosts_and_ports, socket_timeout=None, **kwargs): + self.redis = MockRedis() + self.hosts_and_ports = hosts_and_ports + self.socket_timeout = socket_timeout + + def master_for( + self, service_name, socket_timeout=None, username=None, password=None, **kwargs + ): + return self.redis + + def slave_for( + self, service_name, socket_timeout=None, username=None, password=None, **kwargs + ): + return self.redis + + +@dataclass +class LMCacheServerProcess: + server_url: str + server_process: object + + +@pytest.fixture(scope="function", autouse=True) +def mock_redis(): + with ( + patch("redis.Redis", MockRedis) as mock_redis_class, + patch("redis.from_url", MockRedis.from_url), + ): + yield mock_redis_class + + +@pytest.fixture(scope="function", autouse=True) +def mock_redis_sentinel(): + with patch("redis.Sentinel", MockRedisSentinel) as mock: + yield mock + + +@pytest.fixture(scope="module") +def lmserver_v1_process(request): + def ensure_connection(host, port): + retries = 10 + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + successful = False + while retries > 0: + retries -= 1 + try: + 
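+                # A successful TCP connect means the server is accepting requests; a ConnectionRefusedError means it is still starting up, so sleep briefly and retry.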
print("Probing connection, remaining retries: ", retries) + client_socket.connect((host, port)) + successful = True + break + except ConnectionRefusedError: + time.sleep(1) + print("Connection refused!") + continue + except Exception as e: + print(f"other Exception: {e}") + continue + + client_socket.close() + return successful + + # Specify remote device + device = request.param + + # Start the process + max_retries = 5 + while max_retries > 0: + max_retries -= 1 + port_number = random.randint(10000, 65500) + print("Starting the lmcache v1 server process on port") + proc = subprocess.Popen( + shlex.split( + f"python3 -m lmcache.v1.server localhost {port_number} {device}" + ) + ) + + # Wait for lmcache process to start + time.sleep(5) + + successful = False + if proc.poll() is not None: + successful = True + else: + successful = ensure_connection("localhost", port_number) + + if not successful: + proc.terminate() + proc.wait() + else: + break + + # Yield control back to the test until it finishes + server_url = f"lm://localhost:{port_number}" + yield LMCacheServerProcess(server_url, proc) + + # Terminate the process + proc.terminate() + proc.wait() + + # Destroy remote disk path + if device not in ["cpu"]: + subprocess.run(shlex.split(f"rm -rf {device}")) + + +@pytest.fixture(scope="module") +def lmserver_process(request): + def ensure_connection(host, port): + retries = 10 + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + successful = False + while retries > 0: + retries -= 1 + try: + print("Probing connection, remaining retries: ", retries) + client_socket.connect((host, port)) + successful = True + break + except ConnectionRefusedError: + time.sleep(1) + print("Connection refused!") + continue + except Exception as e: + print(f"other Exception: {e}") + continue + + client_socket.close() + return successful + + # Specify remote device + device = request.param + + # Start the process + max_retries = 5 + while max_retries > 0: + max_retries -= 1 + port_number = random.randint(10000, 65500) + print("Starting the lmcache server process on port") + proc = subprocess.Popen( + shlex.split(f"python3 -m lmcache.server localhost {port_number} {device}") + ) + + # Wait for lmcache process to start + time.sleep(5) + + successful = False + if proc.poll() is not None: + successful = True + else: + successful = ensure_connection("localhost", port_number) + + if not successful: + proc.terminate() + proc.wait() + else: + break + + # Yield control back to the test until it finishes + server_url = f"lm://localhost:{port_number}" + yield LMCacheServerProcess(server_url, proc) + + # Terminate the process + proc.terminate() + proc.wait() + + # Destroy remote disk path + if device not in ["cpu"]: + subprocess.run(shlex.split(f"rm -rf {device}")) + + +@pytest.fixture(scope="function") +def autorelease(request): + objects = [] + + def _factory(obj): + objects.append(obj) + return obj + + yield _factory + + # Cleanup all objects created by the factory + for obj in objects: + obj.close() + + +@pytest.fixture(scope="function") +def autorelease_v1(request): + objects = [] + + def _factory(obj): + objects.append(obj) + return obj + + yield _factory + + LMCacheEngineBuilder.destroy("test") + + # Cleanup all objects created by the factory + # for obj in objects: + # obj.close() diff --git a/tests/v1/storage_backend/test_local_cpu_backend.py b/tests/v1/storage_backend/test_local_cpu_backend.py new file mode 100644 index 0000000..948b3e7 --- /dev/null +++ 
b/tests/v1/storage_backend/test_local_cpu_backend.py @@ -0,0 +1,548 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +import threading + +# Third Party +import pytest +import torch + +# First Party +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import ( + AdHocMemoryAllocator, + MemoryFormat, + MemoryObj, +) +from lmcache_ascend.v1.memory_management import ( + AscendMixedMemoryAllocator as MixedMemoryAllocator, +) +from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend + + +class MockLookupServer: + def __init__(self): + self.removed_keys = [] + self.inserted_keys = [] + + def batched_remove(self, keys): + self.removed_keys.extend(keys) + + def batched_insert(self, keys): + self.inserted_keys.extend(keys) + + +class MockLMCacheWorker: + def __init__(self): + self.messages = [] + + def put_msg(self, msg): + self.messages.append(msg) + + +def create_test_config( + local_cpu: bool = True, use_layerwise: bool = False, enable_blending: bool = False +): + """Create a test configuration for LocalCPUBackend.""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=256, + local_cpu=local_cpu, + use_layerwise=use_layerwise, + enable_blending=enable_blending, + lmcache_instance_id="test_instance", + ) + return config + + +def create_test_key(key_id: str = "test_key") -> CacheEngineKey: + """Create a test CacheEngineKey.""" + return CacheEngineKey("vllm", "test_model", 3, 123, key_id) + + +def create_test_memory_obj(shape=(2, 16, 8, 128), dtype=torch.bfloat16) -> MemoryObj: + """Create a test MemoryObj using AdHocMemoryAllocator for testing.""" + allocator = AdHocMemoryAllocator(device="cpu") + memory_obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_T2D) + return memory_obj + + +@pytest.fixture +def memory_allocator(): + """Create a memory allocator for testing.""" + return MixedMemoryAllocator(1024 * 1024 * 1024) # 1GB + + +@pytest.fixture +def local_cpu_backend(memory_allocator): + """Create a LocalCPUBackend for testing.""" + config = create_test_config() + return LocalCPUBackend(config=config, memory_allocator=memory_allocator) + + +@pytest.fixture +def local_cpu_backend_disabled(memory_allocator): + """Create a LocalCPUBackend with local_cpu disabled.""" + config = create_test_config(local_cpu=False) + return LocalCPUBackend(config=config, memory_allocator=memory_allocator) + + +class TestLocalCPUBackend: + """Test cases for LocalCPUBackend.""" + + def test_init(self, memory_allocator): + """Test LocalCPUBackend initialization.""" + config = create_test_config() + backend = LocalCPUBackend(config=config, memory_allocator=memory_allocator) + + assert backend.use_hot is True + assert backend.lookup_server is None + assert backend.memory_allocator == memory_allocator + assert backend.lmcache_worker is None + assert backend.instance_id == "test_instance" + assert backend.usage == 0 + assert len(backend.hot_cache) == 0 + assert backend.layerwise is False + assert backend.enable_blending is False + + memory_allocator.close() + + def test_init_with_lookup_server_and_worker(self, memory_allocator): + """Test LocalCPUBackend initialization with lookup server and worker.""" + config = create_test_config() + lookup_server = MockLookupServer() + lmcache_worker = MockLMCacheWorker() + + backend = LocalCPUBackend( + config=config, + memory_allocator=memory_allocator, + lookup_server=lookup_server, + lmcache_worker=lmcache_worker, + ) + + assert backend.lookup_server == lookup_server + assert 
backend.lmcache_worker == lmcache_worker + + memory_allocator.close() + + def test_str(self, local_cpu_backend): + """Test string representation.""" + assert str(local_cpu_backend) == "LocalCPUBackend" + + local_cpu_backend.memory_allocator.close() + + def test_contains_key_not_exists(self, local_cpu_backend): + """Test contains() when key doesn't exist.""" + key = create_test_key("nonexistent") + assert not local_cpu_backend.contains(key) + assert not local_cpu_backend.contains(key, pin=True) + + local_cpu_backend.memory_allocator.close() + + def test_contains_key_exists(self, local_cpu_backend): + """Test contains() when key exists.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_cpu_backend.submit_put_task(key, memory_obj) + + assert local_cpu_backend.contains(key) + assert local_cpu_backend.contains(key, pin=True) + + local_cpu_backend.memory_allocator.close() + + def test_exists_in_put_tasks(self, local_cpu_backend): + """Test exists_in_put_tasks().""" + key = create_test_key("test_key") + # LocalCPUBackend always returns False for exists_in_put_tasks + assert not local_cpu_backend.exists_in_put_tasks(key) + local_cpu_backend.memory_allocator.close() + + def test_submit_put_task(self, local_cpu_backend): + """Test submit_put_task().""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + future = local_cpu_backend.submit_put_task(key, memory_obj) + + # LocalCPUBackend returns None for submit_put_task + assert future is None + assert key in local_cpu_backend.hot_cache + assert local_cpu_backend.hot_cache[key] == memory_obj + assert ( + memory_obj.get_ref_count() == 2 + ) # 1 from creation + 1 from submit_put_task + local_cpu_backend.memory_allocator.close() + + def test_submit_put_task_reinsert(self, local_cpu_backend): + """Test submit_put_task() with reinsertion.""" + key = create_test_key("test_key") + memory_obj1 = create_test_memory_obj(shape=(2, 16, 8, 128)) + memory_obj2 = create_test_memory_obj(shape=(2, 32, 8, 128)) + + # First insertion + local_cpu_backend.submit_put_task(key, memory_obj1) + assert local_cpu_backend.hot_cache[key] == memory_obj1 + + # Reinsertion + local_cpu_backend.submit_put_task(key, memory_obj2) + assert local_cpu_backend.hot_cache[key] != memory_obj2 + assert memory_obj1.get_ref_count() == 2 + assert memory_obj2.get_ref_count() == 1 + + local_cpu_backend.memory_allocator.close() + + def test_batched_submit_put_task(self, local_cpu_backend): + """Test batched_submit_put_task().""" + keys = [create_test_key(f"key_{i}") for i in range(3)] + memory_objs = [create_test_memory_obj() for _ in range(3)] + + futures = local_cpu_backend.batched_submit_put_task(keys, memory_objs) + + # LocalCPUBackend returns None for batched_submit_put_task + assert futures is None + + # Check that all keys were inserted + for key, memory_obj in zip(keys, memory_objs, strict=False): + assert key in local_cpu_backend.hot_cache + assert local_cpu_backend.hot_cache[key] == memory_obj + + local_cpu_backend.memory_allocator.close() + + def test_batched_submit_put_task_disabled(self, local_cpu_backend_disabled): + """Test batched_submit_put_task() when local_cpu is disabled.""" + keys = [create_test_key(f"key_{i}") for i in range(3)] + memory_objs = [create_test_memory_obj() for _ in range(3)] + + futures = local_cpu_backend_disabled.batched_submit_put_task(keys, memory_objs) + + # Should return None when local_cpu is disabled + assert futures is None + + 
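+        # The fixture's allocator pins and host-registers a large CPU buffer, so close it explicitly at the end of each test.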
local_cpu_backend_disabled.memory_allocator.close() + + def test_submit_prefetch_task(self, local_cpu_backend): + """Test submit_prefetch_task().""" + key = create_test_key("test_key") + ret = local_cpu_backend.submit_prefetch_task(key) + + # LocalCPUBackend always returns None for submit_prefetch_task + assert ret is False + + local_cpu_backend.memory_allocator.close() + + def test_get_blocking_key_not_exists(self, local_cpu_backend): + """Test get_blocking() when key doesn't exist.""" + key = create_test_key("nonexistent") + result = local_cpu_backend.get_blocking(key) + + assert result is None + + local_cpu_backend.memory_allocator.close() + + def test_get_blocking_key_exists(self, local_cpu_backend): + """Test get_blocking() when key exists.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_cpu_backend.submit_put_task(key, memory_obj) + + result = local_cpu_backend.get_blocking(key) + + assert result is not None + assert isinstance(result, MemoryObj) + assert result == memory_obj + assert ( + result.get_ref_count() == 3 + ) # 1 from creation + 1 from submit_put_task + 1 from get_blocking + + local_cpu_backend.memory_allocator.close() + + def test_get_non_blocking_key_not_exists(self, local_cpu_backend): + """Test get_non_blocking() when key doesn't exist.""" + key = create_test_key("nonexistent") + future = local_cpu_backend.get_non_blocking(key) + + assert future is None + + local_cpu_backend.memory_allocator.close() + + def test_get_non_blocking_key_exists(self, local_cpu_backend): + """Test get_non_blocking() when key exists.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_cpu_backend.submit_put_task(key, memory_obj) + + future = local_cpu_backend.get_non_blocking(key) + + assert future is not None + result = future.result() + assert result is not None + assert isinstance(result, MemoryObj) + assert result == memory_obj + + local_cpu_backend.memory_allocator.close() + + def test_pin_unpin(self, local_cpu_backend): + """Test pin() and unpin() operations.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_cpu_backend.submit_put_task(key, memory_obj) + + # Test pin + assert local_cpu_backend.pin(key) + assert memory_obj.is_pinned + + # Test unpin + assert local_cpu_backend.unpin(key) + assert not memory_obj.is_pinned + + # Test pin/unpin non-existent key + non_existent_key = create_test_key("non_existent") + assert not local_cpu_backend.pin(non_existent_key) + assert not local_cpu_backend.unpin(non_existent_key) + + local_cpu_backend.memory_allocator.close() + + def test_remove(self, local_cpu_backend): + """Test remove().""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_cpu_backend.submit_put_task(key, memory_obj) + assert key in local_cpu_backend.hot_cache + + # Remove the key + result = local_cpu_backend.remove(key) + + assert result is True + assert key not in local_cpu_backend.hot_cache + assert memory_obj.get_ref_count() == 1 # Should be decremented + + local_cpu_backend.memory_allocator.close() + + def test_remove_non_existent(self, local_cpu_backend): + """Test remove() with non-existent key.""" + key = create_test_key("nonexistent") + result = local_cpu_backend.remove(key) + + assert result is False + + local_cpu_backend.memory_allocator.close() + + def test_remove_without_free(self, local_cpu_backend): + """Test remove() with 
free_obj=False.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_cpu_backend.submit_put_task(key, memory_obj) + initial_ref_count = memory_obj.get_ref_count() + + # Remove the key without freeing the object + result = local_cpu_backend.remove(key, free_obj=False) + + assert result is True + assert key not in local_cpu_backend.hot_cache + assert ( + memory_obj.get_ref_count() == initial_ref_count + ) # Should not be decremented + + local_cpu_backend.memory_allocator.close() + + def test_remove_with_worker(self, memory_allocator): + """Test remove() with LMCacheWorker.""" + config = create_test_config() + lmcache_worker = MockLMCacheWorker() + backend = LocalCPUBackend( + config=config, + memory_allocator=memory_allocator, + lmcache_worker=lmcache_worker, + ) + + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + backend.submit_put_task(key, memory_obj) + + # Remove the key + backend.remove(key) + + # Check that evict message was sent + assert len(lmcache_worker.messages) == 2 # 1 admit + 1 evict + # First Party + from lmcache.v1.cache_controller.message import KVAdmitMsg, KVEvictMsg + + assert any(isinstance(msg, KVAdmitMsg) for msg in lmcache_worker.messages) + assert any(isinstance(msg, KVEvictMsg) for msg in lmcache_worker.messages) + + memory_allocator.close() + + def test_allocate(self, local_cpu_backend): + """Test allocate().""" + shape = torch.Size([2, 16, 8, 128]) + dtype = torch.bfloat16 + + memory_obj = local_cpu_backend.allocate(shape, dtype) + + assert memory_obj is not None + assert isinstance(memory_obj, MemoryObj) + assert memory_obj.metadata.shape == shape + assert memory_obj.metadata.dtype == dtype + + local_cpu_backend.memory_allocator.close() + + def test_allocate_with_format(self, local_cpu_backend): + """Test allocate() with specific format.""" + shape = torch.Size([2, 16, 8, 128]) + dtype = torch.bfloat16 + fmt = MemoryFormat.KV_2LTD + + memory_obj = local_cpu_backend.allocate(shape, dtype, fmt) + + assert memory_obj is not None + assert memory_obj.metadata.fmt == fmt + + local_cpu_backend.memory_allocator.close() + + def test_batched_allocate(self, local_cpu_backend): + """Test batched_allocate().""" + shape = torch.Size([2, 16, 8, 128]) + dtype = torch.bfloat16 + batch_size = 3 + + memory_objs = local_cpu_backend.batched_allocate(shape, dtype, batch_size) + + assert memory_objs is not None + assert len(memory_objs) == batch_size + for memory_obj in memory_objs: + assert isinstance(memory_obj, MemoryObj) + assert memory_obj.metadata.shape == shape + assert memory_obj.metadata.dtype == dtype + + local_cpu_backend.memory_allocator.close() + + def test_get_keys(self, local_cpu_backend): + """Test get_keys().""" + keys = [create_test_key(f"key_{i}") for i in range(3)] + memory_objs = [create_test_memory_obj() for _ in range(3)] + + # Insert keys + for key, memory_obj in zip(keys, memory_objs, strict=False): + local_cpu_backend.submit_put_task(key, memory_obj) + + # Get keys + retrieved_keys = local_cpu_backend.get_keys() + + assert len(retrieved_keys) == 3 + assert all(key in retrieved_keys for key in keys) + + local_cpu_backend.memory_allocator.close() + + def test_get_keys_empty(self, local_cpu_backend): + """Test get_keys() when cache is empty.""" + keys = local_cpu_backend.get_keys() + + assert len(keys) == 0 + + local_cpu_backend.memory_allocator.close() + + def test_concurrent_access(self, local_cpu_backend): + """Test concurrent access to the backend.""" + 
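+        # Hammer contains() on the same key from several reader threads to verify the hot cache tolerates concurrent lookups.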
key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key + local_cpu_backend.submit_put_task(key, memory_obj) + + # Test concurrent contains() calls + def check_contains(): + for _ in range(20): + assert local_cpu_backend.contains(key) + + threads = [threading.Thread(target=check_contains) for _ in range(3)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + local_cpu_backend.memory_allocator.close() + + def test_thread_safety(self, local_cpu_backend): + """Test thread safety of the backend.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key + local_cpu_backend.submit_put_task(key, memory_obj) + + # Test concurrent operations + def concurrent_operations(): + for _ in range(10): + # Test contains + local_cpu_backend.contains(key) + # Test pin/unpin + local_cpu_backend.pin(key) + local_cpu_backend.unpin(key) + # Test get_blocking + result = local_cpu_backend.get_blocking(key) + assert result is not None + + threads = [threading.Thread(target=concurrent_operations) for _ in range(3)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # The backend should still be in a consistent state + assert local_cpu_backend.contains(key) + + local_cpu_backend.memory_allocator.close() + + def test_memory_usage_tracking(self, local_cpu_backend): + """Test that memory usage is tracked correctly.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + initial_usage = local_cpu_backend.usage + + # Insert key + local_cpu_backend.submit_put_task(key, memory_obj) + + # Usage should be updated + assert local_cpu_backend.usage > initial_usage + + # Remove key + local_cpu_backend.remove(key) + + # Usage should be reduced + assert local_cpu_backend.usage == initial_usage + + local_cpu_backend.memory_allocator.close() + + def test_ref_count_management(self, local_cpu_backend): + """Test reference count management.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + initial_ref_count = memory_obj.get_ref_count() + + # Insert key + local_cpu_backend.submit_put_task(key, memory_obj) + assert memory_obj.get_ref_count() == initial_ref_count + 1 + + # Get blocking + local_cpu_backend.get_blocking(key) + assert memory_obj.get_ref_count() == initial_ref_count + 2 + + # Remove key + local_cpu_backend.remove(key) + assert memory_obj.get_ref_count() == initial_ref_count + 1 + local_cpu_backend.memory_allocator.close() diff --git a/tests/v1/storage_backend/test_local_disk_backend.py b/tests/v1/storage_backend/test_local_disk_backend.py new file mode 100644 index 0000000..8cd036a --- /dev/null +++ b/tests/v1/storage_backend/test_local_disk_backend.py @@ -0,0 +1,629 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +import asyncio +import os +import shutil +import tempfile +import threading + +# Third Party +import pytest +import torch + +# First Party +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import ( + MemoryFormat, + MemoryObj, +) +from lmcache_ascend.v1.memory_management import ( + AscendMixedMemoryAllocator as MixedMemoryAllocator, +) +from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend +from lmcache.v1.storage_backend.local_disk_backend import LocalDiskBackend + + +class MockLookupServer: + def __init__(self): + self.removed_keys = [] + self.inserted_keys = [] + + def batched_remove(self, keys): + 
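+        # Record the keys so tests can assert exactly which keys the backend asked the lookup server to remove.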
self.removed_keys.extend(keys) + + def batched_insert(self, keys): + self.inserted_keys.extend(keys) + + +class MockLMCacheWorker: + def __init__(self): + self.messages = [] + + def put_msg(self, msg): + self.messages.append(msg) + + +def create_test_config(disk_path: str, max_disk_size: float = 1.0): + """Create a test configuration for LocalDiskBackend.""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=256, + local_disk=disk_path, + max_local_disk_size=max_disk_size, + lmcache_instance_id="test_instance", + ) + return config + + +def create_test_key(key_id: str = "test_key") -> CacheEngineKey: + """Create a test CacheEngineKey.""" + return CacheEngineKey("vllm", "test_model", 3, 123, key_id) + + +def create_test_memory_obj(shape=(2, 16, 8, 128), dtype=torch.bfloat16) -> MemoryObj: + """Create a test MemoryObj using AdHocMemoryAllocator for testing.""" + # First Party + from lmcache.v1.memory_management import AdHocMemoryAllocator, MemoryFormat + + allocator = AdHocMemoryAllocator(device="cpu") + memory_obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_T2D) + return memory_obj + + +@pytest.fixture +def temp_disk_path(): + """Create a temporary directory for disk storage tests.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + # Cleanup + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + +@pytest.fixture +def async_loop(): + """Create an asyncio event loop for testing.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() + + +@pytest.fixture +def local_cpu_backend(): + """Create a LocalCPUBackend for testing.""" + config = LMCacheEngineConfig.from_legacy(chunk_size=256) + memory_allocator = MixedMemoryAllocator(1024 * 1024 * 1024) # 1GB + return LocalCPUBackend(config, memory_allocator) + + +@pytest.fixture +def local_disk_backend(temp_disk_path, async_loop, local_cpu_backend): + """Create a LocalDiskBackend for testing.""" + config = create_test_config(temp_disk_path) + return LocalDiskBackend( + config=config, + loop=async_loop, + local_cpu_backend=local_cpu_backend, + dst_device="cuda", + ) + + +class TestLocalDiskBackend: + """Test cases for LocalDiskBackend.""" + + def test_init(self, temp_disk_path, async_loop, local_cpu_backend): + """Test LocalDiskBackend initialization.""" + config = create_test_config(temp_disk_path) + backend = LocalDiskBackend( + config=config, + loop=async_loop, + local_cpu_backend=local_cpu_backend, + dst_device="cuda", + ) + + assert backend.dst_device == "cuda" + assert backend.local_cpu_backend == local_cpu_backend + assert backend.path == temp_disk_path + assert os.path.exists(temp_disk_path) + assert backend.lookup_server is None + assert backend.lmcache_worker is None + assert backend.instance_id == "test_instance" + assert backend.usage == 0 + assert len(backend.dict) == 0 + + local_cpu_backend.memory_allocator.close() + + def test_init_with_lookup_server_and_worker( + self, temp_disk_path, async_loop, local_cpu_backend + ): + """Test LocalDiskBackend initialization with lookup server and worker.""" + config = create_test_config(temp_disk_path) + lookup_server = MockLookupServer() + lmcache_worker = MockLMCacheWorker() + + backend = LocalDiskBackend( + config=config, + loop=async_loop, + local_cpu_backend=local_cpu_backend, + dst_device="cuda", + lookup_server=lookup_server, + lmcache_worker=lmcache_worker, + ) + + assert backend.lookup_server == lookup_server + assert backend.lmcache_worker == lmcache_worker + + local_cpu_backend.memory_allocator.close() + + def 
test_str(self, local_disk_backend): + """Test string representation.""" + assert str(local_disk_backend) == "LocalDiskBackend" + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_key_to_path(self, local_disk_backend): + """Test key to path conversion.""" + key = create_test_key("test_hash") + path = local_disk_backend._key_to_path(key) + + expected_filename = key.to_string().replace("/", "-") + ".pt" + assert path == os.path.join(local_disk_backend.path, expected_filename) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_contains_key_not_exists(self, local_disk_backend): + """Test contains() when key doesn't exist.""" + key = create_test_key("nonexistent") + assert not local_disk_backend.contains(key) + assert not local_disk_backend.contains(key, pin=True) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_contains_key_exists(self, local_disk_backend): + """Test contains() when key exists.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_disk_backend.insert_key(key, memory_obj) + + assert local_disk_backend.contains(key) + assert local_disk_backend.contains(key, pin=True) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_pin_unpin(self, local_disk_backend): + """Test pin() and unpin() operations.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + # Insert key first + local_disk_backend.insert_key(key, memory_obj) + # Test pin + assert local_disk_backend.pin(key) + assert local_disk_backend.dict[key].pin_count > 0 + # Test unpin + assert local_disk_backend.unpin(key) + assert local_disk_backend.dict[key].pin_count == 0 + + # Test pin/unpin non-existent key + non_existent_key = create_test_key("non_existent") + assert not local_disk_backend.pin(non_existent_key) + assert not local_disk_backend.unpin(non_existent_key) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_insert_key(self, local_disk_backend): + """Test insert_key().""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + local_disk_backend.insert_key(key, memory_obj) + assert key in local_disk_backend.dict + metadata = local_disk_backend.dict[key] + assert metadata.path == local_disk_backend._key_to_path(key) + assert metadata.size == memory_obj.get_size() + assert metadata.shape == memory_obj.metadata.shape + assert metadata.dtype == memory_obj.metadata.dtype + assert metadata.fmt == memory_obj.metadata.fmt + assert metadata.pin_count == 0 + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_insert_key_reinsert(self, local_disk_backend): + """Test insert_key() with reinsertion.""" + key = create_test_key("test_key") + memory_obj1 = create_test_memory_obj(shape=(2, 16, 8, 128)) + memory_obj2 = create_test_memory_obj(shape=(2, 32, 8, 128)) + + # First insertion + local_disk_backend.insert_key(key, memory_obj1) + original_path = local_disk_backend.dict[key].path + + # Reinsertion + local_disk_backend.insert_key(key, memory_obj2) + + assert key in local_disk_backend.dict + metadata = local_disk_backend.dict[key] + assert metadata.path == original_path # Path should remain the same + assert metadata.size == memory_obj2.get_size() # Size should be updated + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_remove(self, local_disk_backend): + """Test remove().""" + key = create_test_key("test_key") + memory_obj = 
create_test_memory_obj() + + # Insert key first + local_disk_backend.insert_key(key, memory_obj) + assert key in local_disk_backend.dict + + # Create a dummy file to simulate the disk file + path = local_disk_backend._key_to_path(key) + with open(path, "wb") as f: + f.write(b"dummy data") + + # Remove the key + local_disk_backend.remove(key) + + assert key not in local_disk_backend.dict + assert not os.path.exists(path) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_remove_with_worker(self, temp_disk_path, async_loop, local_cpu_backend): + """Test remove() with LMCacheWorker.""" + config = create_test_config(temp_disk_path) + lmcache_worker = MockLMCacheWorker() + backend = LocalDiskBackend( + config=config, + loop=async_loop, + local_cpu_backend=local_cpu_backend, + dst_device="cuda", + lmcache_worker=lmcache_worker, + ) + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + # Insert key first + backend.insert_key(key, memory_obj) + # Create a dummy file + path = backend._key_to_path(key) + with open(path, "wb") as f: + f.write(b"dummy data") + # Remove the key + backend.remove(key) + # Check that both admit and evict messages were sent + assert len(lmcache_worker.messages) == 2 + # First Party + from lmcache.v1.cache_controller.message import KVAdmitMsg, KVEvictMsg + + assert any(isinstance(msg, KVAdmitMsg) for msg in lmcache_worker.messages) + assert any(isinstance(msg, KVEvictMsg) for msg in lmcache_worker.messages) + + local_cpu_backend.memory_allocator.close() + + def test_submit_put_task(self, local_disk_backend): + """Test submit_put_task() synchronous""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Test that the key is not in put_tasks initially + assert not local_disk_backend.exists_in_put_tasks(key) + + # Test that the key doesn't exist in the backend initially + assert not local_disk_backend.contains(key) + + # Use insert_key directly to test the synchronous path + local_disk_backend.insert_key(key, memory_obj) + + # Check that the key was inserted into the backend + assert local_disk_backend.contains(key) + assert key in local_disk_backend.dict + + # Check that the metadata was properly set + metadata = local_disk_backend.dict[key] + assert metadata.path == local_disk_backend._key_to_path(key) + assert metadata.size == memory_obj.get_size() + assert metadata.shape == memory_obj.metadata.shape + assert metadata.dtype == memory_obj.metadata.dtype + assert metadata.fmt == memory_obj.metadata.fmt + assert metadata.pin_count == 0 + + # Test that the key is still not in put_tasks + # (since we used insert_key directly) + assert not local_disk_backend.exists_in_put_tasks(key) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_submit_put_task_with_eviction( + self, temp_disk_path, async_loop, local_cpu_backend + ): + """Test submit_put_task() with eviction.""" + config = create_test_config( + temp_disk_path, max_disk_size=0.001 + ) # Very small size + backend = LocalDiskBackend( + config=config, + loop=async_loop, + local_cpu_backend=local_cpu_backend, + dst_device="cuda", + ) + + # Add multiple keys to trigger eviction + for i in range(5): + key = create_test_key(f"key_{i}") + memory_obj = create_test_memory_obj() + backend.insert_key(key, memory_obj) + + # Test that the evictor is working by checking the cache size + # The evictor should manage the cache size based on max_disk_size + assert len(backend.dict) <= 5 + + # Test that the evictor is properly initialized 
+ assert backend.evictor is not None + + local_cpu_backend.memory_allocator.close() + + def test_submit_prefetch_task_key_not_exists(self, local_disk_backend): + """Test submit_prefetch_task() when key doesn't exist.""" + key = create_test_key("nonexistent") + res = local_disk_backend.submit_prefetch_task(key) + + assert not res + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_submit_prefetch_task_key_exists(self, local_disk_backend): + """Test submit_prefetch_task() when key exists.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_disk_backend.insert_key(key, memory_obj) + + # Create the actual file on disk + path = local_disk_backend._key_to_path(key) + with open(path, "wb") as f: + f.write(memory_obj.byte_array) + + future = local_disk_backend.submit_prefetch_task(key) + + assert future is not None + # Don't call future.result() to avoid blocking + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_get_blocking_key_not_exists(self, local_disk_backend): + """Test get_blocking() when key doesn't exist.""" + key = create_test_key("nonexistent") + result = local_disk_backend.get_blocking(key) + + assert result is None + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_get_blocking_key_exists(self, local_disk_backend): + """Test get_blocking() when key exists.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key first + local_disk_backend.insert_key(key, memory_obj) + + # Create the actual file on disk + path = local_disk_backend._key_to_path(key) + with open(path, "wb") as f: + f.write(memory_obj.byte_array) + + result = local_disk_backend.get_blocking(key) + + assert result is not None + assert isinstance(result, MemoryObj) + assert result.metadata.shape == memory_obj.metadata.shape + assert result.metadata.dtype == memory_obj.metadata.dtype + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_async_save_bytes_to_disk(self, local_disk_backend, async_loop): + """Test async_save_bytes_to_disk().""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + local_disk_backend.insert_key(key, memory_obj) + + # Check that the key was inserted into the backend + assert key in local_disk_backend.dict + + # Check that the metadata was properly set + metadata = local_disk_backend.dict[key] + assert metadata.path == local_disk_backend._key_to_path(key) + assert metadata.size == memory_obj.get_size() + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_async_load_bytes_from_disk(self, local_disk_backend): + """Test async_load_bytes_from_disk()""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Create the file first + path = local_disk_backend._key_to_path(key) + with open(path, "wb") as f: + f.write(memory_obj.byte_array) + + result = local_disk_backend.load_bytes_from_disk( + path, + memory_obj.metadata.dtype, + memory_obj.metadata.shape, + memory_obj.metadata.fmt, + ) + + assert result is not None + assert isinstance(result, MemoryObj) + assert result.metadata.shape == memory_obj.metadata.shape + assert result.metadata.dtype == memory_obj.metadata.dtype + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_load_bytes_from_disk(self, local_disk_backend): + """Test load_bytes_from_disk().""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Create the 
file first + path = local_disk_backend._key_to_path(key) + with open(path, "wb") as f: + f.write(memory_obj.byte_array) + + result = local_disk_backend.load_bytes_from_disk( + path, + memory_obj.metadata.dtype, + memory_obj.metadata.shape, + memory_obj.metadata.fmt, + ) + + assert result is not None + assert isinstance(result, MemoryObj) + assert result.metadata.shape == memory_obj.metadata.shape + assert result.metadata.dtype == memory_obj.metadata.dtype + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_close(self, temp_disk_path, async_loop, local_cpu_backend): + """Test close().""" + config = create_test_config(temp_disk_path) + lookup_server = MockLookupServer() + + backend = LocalDiskBackend( + config=config, + loop=async_loop, + local_cpu_backend=local_cpu_backend, + dst_device="cuda", + lookup_server=lookup_server, + ) + + # Add some keys + for i in range(3): + key = create_test_key(f"key_{i}") + memory_obj = create_test_memory_obj() + backend.insert_key(key, memory_obj) + + # Close the backend + backend.close() + + # Check that keys were removed from lookup server + assert len(lookup_server.removed_keys) == 3 + + local_cpu_backend.memory_allocator.close() + + def test_concurrent_access(self, local_disk_backend): + """Test concurrent access to the backend.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key + local_disk_backend.insert_key(key, memory_obj) + + # Test concurrent contains() calls + def check_contains(): + for _ in range(20): + assert local_disk_backend.contains(key) + + threads = [threading.Thread(target=check_contains) for _ in range(3)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_file_operations_error_handling(self, local_disk_backend): + """Test error handling in file operations.""" + # Test with non-existent file + non_existent_path = "/non/existent/path/file.pt" + + with pytest.raises(FileNotFoundError): + local_disk_backend.load_bytes_from_disk( + non_existent_path, + torch.bfloat16, + torch.Size([2, 16, 8, 128]), + MemoryFormat.KV_T2D, + ) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_evictor_integration(self, local_disk_backend): + """Test integration with the LRU evictor.""" + # Add multiple keys to test eviction + keys = [] + memory_objs = [] + + for i in range(10): + key = create_test_key(f"key_{i}") + memory_obj = create_test_memory_obj() + keys.append(key) + memory_objs.append(memory_obj) + local_disk_backend.insert_key(key, memory_obj) + + # Test that evictor is working + assert len(local_disk_backend.dict) == 10 + + # The evictor should be managing the cache size + assert local_disk_backend.evictor is not None + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def test_cleanup_on_remove(self, local_disk_backend): + """Test that resources are properly cleaned up on remove.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key + local_disk_backend.insert_key(key, memory_obj) + + # Create the file + path = local_disk_backend._key_to_path(key) + with open(path, "wb") as f: + f.write(memory_obj.byte_array) + + # Remove key + local_disk_backend.remove(key) + + # Check that both the dict entry and file are removed + assert key not in local_disk_backend.dict + assert not os.path.exists(path) + + local_disk_backend.local_cpu_backend.memory_allocator.close() + + def 
test_thread_safety(self, local_disk_backend): + """Test thread safety of the backend.""" + key = create_test_key("test_key") + memory_obj = create_test_memory_obj() + + # Insert key + local_disk_backend.insert_key(key, memory_obj) + + path = local_disk_backend._key_to_path(key) + with open(path, "wb") as f: + f.write(memory_obj.byte_array) + + # Test concurrent operations with reduced iteration count + def concurrent_operations(): + for _ in range(10): + # Test contains + local_disk_backend.contains(key) + # Test pin/unpin + local_disk_backend.pin(key) + local_disk_backend.unpin(key) + # Test get_blocking + result = local_disk_backend.get_blocking(key) + assert result is not None + + threads = [threading.Thread(target=concurrent_operations) for _ in range(3)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # The backend should still be in a consistent state + assert local_disk_backend.contains(key) + + local_disk_backend.local_cpu_backend.memory_allocator.close() diff --git a/tests/v1/test_cache_engine.py b/tests/v1/test_cache_engine.py new file mode 100644 index 0000000..da29e38 --- /dev/null +++ b/tests/v1/test_cache_engine.py @@ -0,0 +1,882 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from copy import deepcopy +import random +import shlex +import subprocess +import time + +# Third Party +import pytest +import torch + +# First Party +from utils import ( + check_paged_kv_cache_equal, + create_gpu_connector, + dumb_metadata, + generate_kv_cache_paged_list_tensors, + generate_tokens, +) +from lmcache.v1.cache_engine import LMCacheEngineBuilder +from lmcache.v1.config import LMCacheEngineConfig + + +def test_paged_same_retrieve_store(autorelease_v1): + device = "cuda" + fmt = "vllm" + num_tokens = 2000 + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + + chunk_size = 256 + kv_shape = (32, 2, chunk_size, 8, 128) + + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_tokens, device) + + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + + original_retrieved_cache = deepcopy(retrieved_cache) + + slot_mapping = random.sample(range(0, num_blocks * block_size), num_tokens) + slot_mapping = torch.tensor(slot_mapping, device=device) + + # Check the kv cache and the retrieval buffer are not the same + check_paged_kv_cache_equal(retrieved_cache, original_retrieved_cache, slot_mapping) + with pytest.raises(AssertionError): + check_paged_kv_cache_equal(retrieved_cache, kv_cache, slot_mapping) + """ initialize the engine """ + cfg = LMCacheEngineConfig.from_legacy(chunk_size=chunk_size, remote_url=None) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + """ test retrieve empty """ + ret_mask = engine.retrieve( + tokens, kvcaches=retrieved_cache, slot_mapping=slot_mapping + ) + length = torch.sum(ret_mask) + assert length == 0 + check_paged_kv_cache_equal(retrieved_cache, original_retrieved_cache, slot_mapping) + """ test store """ + engine.store(tokens=tokens, kvcaches=kv_cache, slot_mapping=slot_mapping) + + """ Store is async. 
Need to wait for the store to finish """ + timeout = 1.5 + start_time = time.time() + while engine.lookup(tokens) < num_tokens: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ test retrieve """ + ret_mask = engine.retrieve( + tokens, kvcaches=retrieved_cache, slot_mapping=slot_mapping + ) + length = torch.sum(ret_mask) + assert length == num_tokens + check_paged_kv_cache_equal(retrieved_cache, kv_cache, slot_mapping) + + +# TODO (Gingfung): once cachegen add back remote_cachegen +@pytest.mark.parametrize("fmt", ["vllm"]) +@pytest.mark.parametrize("chunk_size", [128, 256]) +@pytest.mark.parametrize("backend", ["cpu", "local_disk", "remote"]) +@pytest.mark.parametrize("lmserver_v1_process", ["cpu"], indirect=True) +def test_paged_retrieve_prefix( + fmt, chunk_size, backend, lmserver_v1_process, autorelease_v1 +): + url = None + remote_serde = None + check_equality = True + if "remote" in backend: + url = lmserver_v1_process.server_url + if backend == "remote_cachegen": + backend = "remote" + remote_serde = "cachegen" + check_equality = False + else: + remote_serde = "naive" + device = "cuda" + num_tokens = 2000 + new_num_tokens = 1000 + kv_shape = (32, 2, chunk_size, 8, 128) + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_tokens, device) + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + new_tokens = generate_tokens(new_num_tokens, device) + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + slot_mapping_full = random.sample( + range(0, num_blocks * block_size), num_tokens + new_num_tokens + ) + slot_mapping = torch.tensor(slot_mapping_full[:num_tokens], device=device) + + new_slot_mapping = torch.tensor(slot_mapping_full[-new_num_tokens:], device=device) + """ initialize the engine """ + cfg = LMCacheEngineConfig.from_legacy( + chunk_size=chunk_size, + backend=backend, + remote_url=url, + remote_serde=remote_serde, + ) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + """ test store """ + t1 = time.perf_counter() + engine.store(tokens, kvcaches=kv_cache, slot_mapping=slot_mapping) + t2 = time.perf_counter() + print(f"store {len(tokens)} takes {t2 - t1}") + """ Compute expected length """ + expected_chunk_cnt = num_tokens // chunk_size + expected_length = expected_chunk_cnt * chunk_size + """ Store is async. 
Need to wait for the store to finish """ + if backend == "cpu": + timeout = 1 + search_range = "LocalCPUBackend" + elif backend == "local_disk": + timeout = 30 + search_range = "LocalDiskBackend" + elif backend == "remote": + timeout = 30 + search_range = "RemoteBackend" + start_time = time.time() + while engine.lookup(tokens, search_range) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ test retrieve """ + t4 = time.perf_counter() + ret_mask = engine.retrieve( + torch.cat([tokens, new_tokens]), + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + + length = torch.sum(ret_mask) + t5 = time.perf_counter() + print(f"retrieve {length} takes {t5 - t4}") + + assert length == expected_length + + if check_equality: + check_paged_kv_cache_equal( + kv_cache, + retrieved_cache, + torch.cat([slot_mapping, new_slot_mapping])[:expected_length], + ) + + if backend in ["local_disk"]: + subprocess.run(shlex.split("rm -rf local/disk_test/local_disk/")) + + +@pytest.mark.parametrize("fmt", ["vllm"]) +@pytest.mark.parametrize("chunk_size", [256]) +@pytest.mark.parametrize( + "backend", + ["cpu", "local_disk", "remote"], +) +@pytest.mark.parametrize("lmserver_v1_process", ["cpu"], indirect=True) +def test_paged_store_offset( + fmt, chunk_size, backend, lmserver_v1_process, autorelease_v1 +): + url = None + if backend == "remote": + url = lmserver_v1_process.server_url + device = "cuda" + num_tokens = 2000 + num_suffix_tokens = 500 + num_total_tokens = 3000 + kv_shape = (32, 2, chunk_size, 8, 128) + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_total_tokens, device) + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + slot_mapping = random.sample(range(0, num_blocks * block_size), num_total_tokens) + slot_mapping = torch.tensor(slot_mapping, device=device) + + """ initialize the engine """ + cfg = LMCacheEngineConfig.from_legacy( + chunk_size=chunk_size, backend=backend, remote_url=url + ) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + """ test store """ + engine.store( + tokens[:num_tokens], + kvcaches=kv_cache, + slot_mapping=slot_mapping[:num_tokens], + ) + + offset_chunk_cnt = num_tokens // chunk_size + offset_length = offset_chunk_cnt * chunk_size + mask = torch.ones(num_tokens + num_suffix_tokens, device=device) + mask[:offset_length] = 0 + engine.store( + tokens[: num_tokens + num_suffix_tokens], + kvcaches=kv_cache, + mask=mask, + slot_mapping=slot_mapping[: num_tokens + num_suffix_tokens], + ) + """ Compute expected length """ + expected_chunk_cnt = (num_tokens + num_suffix_tokens) // chunk_size + expected_length = expected_chunk_cnt * chunk_size + """ Store is async. 
Need to wait for the store to finish """ + if backend == "cpu": + timeout = 1 + elif backend == "local_disk": + timeout = 30 + start_time = time.time() + while engine.lookup(tokens[: num_tokens + num_suffix_tokens]) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ test retrieve """ + t4 = time.perf_counter() + ret_mask = engine.retrieve( + tokens, kvcaches=retrieved_cache, slot_mapping=slot_mapping + ) + + length = torch.sum(ret_mask) + t5 = time.perf_counter() + print(f"retrieve {length} takes {t5 - t4}") + + assert length == expected_length + check_paged_kv_cache_equal( + kv_cache, + retrieved_cache, + slot_mapping[:expected_length], + ) + + if backend in ["local_disk"]: + subprocess.run(shlex.split("rm -rf local/disk_test/local_disk/")) + + +@pytest.mark.parametrize("fmt", ["vllm"]) +@pytest.mark.parametrize("chunk_size", [128]) # , 256]) +@pytest.mark.parametrize( + "backend", + [ + # "cpu", + "local_disk" + ], +) +def test_paged_mixed_retrieve(fmt, chunk_size, backend, autorelease_v1): + device = "cuda" + num_tokens = 2000 + new_num_tokens = 1000 + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + + kv_shape = (32, 2, chunk_size, 8, 128) + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_tokens, device) + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + new_tokens = generate_tokens(new_num_tokens, device) + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + + slot_mapping_full = random.sample( + range(0, num_blocks * block_size), num_tokens + new_num_tokens + ) + slot_mapping = torch.tensor(slot_mapping_full[:num_tokens], device=device) + + new_slot_mapping = torch.tensor(slot_mapping_full[-new_num_tokens:], device=device) + + """ initialize the engine """ + cfg = LMCacheEngineConfig.from_legacy(chunk_size=chunk_size, backend=backend) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + """ test store """ + engine.store(tokens, kvcaches=kv_cache, slot_mapping=slot_mapping) + engine.store(new_tokens, kvcaches=kv_cache, slot_mapping=new_slot_mapping) + """ Store is async. 
Need to wait for the store to finish """ + expected_chunk_cnt = num_tokens // chunk_size + expected_length = expected_chunk_cnt * chunk_size + if backend == "cpu": + timeout = 1 + search_range = "LocalCPUBackend" + elif backend == "local_disk": + timeout = 30 + search_range = "LocalDiskBackend" + start_time = time.time() + while engine.lookup(tokens, search_range) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ test retrieve """ + ret_mask = engine.retrieve( + torch.cat([tokens, new_tokens]), + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + length = torch.sum(ret_mask) + assert length == expected_length + check_paged_kv_cache_equal( + retrieved_cache, + kv_cache, + torch.cat([slot_mapping, new_slot_mapping])[:expected_length], + ) + + """Wait for store to finish""" + expected_length = new_num_tokens + start_time = time.time() + while engine.lookup(new_tokens, search_range) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ test another retrieve """ + ret_mask = engine.retrieve( + new_tokens, kvcaches=retrieved_cache, slot_mapping=new_slot_mapping + ) + length = torch.sum(ret_mask) + assert length == expected_length + check_paged_kv_cache_equal( + retrieved_cache, kv_cache, new_slot_mapping[:expected_length] + ) + + """ insert the mixed kv cache """ + final_tokens = torch.cat([tokens, new_tokens]) + engine.store( + final_tokens, + kvcaches=kv_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + + """Wait until store finishes""" + expected_length = num_tokens + new_num_tokens + start_time = time.time() + while ( + engine.lookup(torch.cat([tokens, new_tokens]), search_range) < expected_length + ): + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ should retrieve the mixed version """ + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + ret_mask = engine.retrieve( + final_tokens, + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + length = torch.sum(ret_mask) + assert length == expected_length + + check_paged_kv_cache_equal( + retrieved_cache, + kv_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + """destroy local disk path""" + if backend in ["local_disk"]: + subprocess.run(shlex.split("rm -rf local/disk_test/local_disk/")) + + +@pytest.mark.parametrize("fmt", ["vllm"]) +def test_paged_store_kv_tensors_mask(fmt, autorelease_v1): + device = "cuda" + num_tokens = 1000 + new_num_tokens = 2000 + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + + chunk_size = 256 + kv_shape = (32, 2, chunk_size, 8, 128) + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_tokens, device) + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype=dtype + ) + + new_tokens = generate_tokens(new_num_tokens, device) + final_tokens = torch.cat([tokens, new_tokens]) + + slot_mapping_full = random.sample( + range(0, num_blocks * block_size), num_tokens + new_num_tokens + ) + slot_mapping = torch.tensor(slot_mapping_full[:num_tokens], device=device) + + new_slot_mapping = torch.tensor(slot_mapping_full[-new_num_tokens:], device=device) + + cfg = 
LMCacheEngineConfig.from_legacy(chunk_size=chunk_size) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + """ Store some tokens with mask """ + engine.store(tokens, kvcaches=kv_cache, slot_mapping=slot_mapping) + """Wait until store finishes""" + timeout = 1 + start_time = time.time() + while engine.lookup(tokens) < num_tokens: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + + prefix_length = engine.lookup(tokens) + assert prefix_length == num_tokens, ( + f"Expected {num_tokens} prefix tokens, but got {prefix_length}" + ) + """ Store more tokens """ + prefix_length = engine.lookup(final_tokens) + kv_tensor_mask = torch.ones_like(final_tokens, dtype=torch.bool) + kv_tensor_mask[:prefix_length] = False + + engine.store( + final_tokens, + mask=kv_tensor_mask, + kvcaches=kv_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + """Wait until store finishes""" + start_time = time.time() + while engine.lookup(final_tokens) < num_tokens + new_num_tokens: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + + prefix_length = engine.lookup(final_tokens) + assert prefix_length == num_tokens + new_num_tokens, ( + f"Expected {num_tokens + new_num_tokens} prefix tokens, but got {prefix_length}" + ) + """ retrieve the whole cache """ + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype=dtype + ) + ret_mask = engine.retrieve( + final_tokens, + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + length = torch.sum(ret_mask) + expected_length = num_tokens + new_num_tokens + assert length == expected_length + check_paged_kv_cache_equal( + retrieved_cache, + kv_cache, + torch.cat([slot_mapping, new_slot_mapping])[:expected_length], + ) + + """ retrieve cache with some mask: + """ + num_falses = chunk_size * 3 + mask = torch.ones_like(final_tokens, dtype=torch.bool) + mask[:num_falses] = False + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype=dtype + ) + ret_mask = engine.retrieve( + final_tokens, + mask=mask, + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + length = torch.sum(ret_mask) + full_length = num_tokens + new_num_tokens + expected_length = full_length - num_falses + assert length == expected_length + + with pytest.raises(AssertionError): + check_paged_kv_cache_equal( + retrieved_cache, + kv_cache, + torch.cat([slot_mapping, new_slot_mapping])[:full_length], + ) + check_paged_kv_cache_equal( + retrieved_cache, + kv_cache, + torch.cat([slot_mapping, new_slot_mapping])[num_falses:full_length], + ) + + mask[: num_falses + 5] = False + with pytest.raises(ValueError): + engine.retrieve( + final_tokens, + mask=mask, + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + + +@pytest.mark.parametrize("fmt", ["vllm"]) +@pytest.mark.parametrize("chunk_size", [128]) +@pytest.mark.parametrize( + "backend", + [ + "local_cpu_disk_remote", + ], +) +@pytest.mark.parametrize( + "retrieve_from", + [ + "local_cpu", + "local_disk", + "remote", + ], +) +@pytest.mark.parametrize("lmserver_v1_process", ["cpu"], indirect=True) +def test_paged_hierarchy_retrieve( + fmt, chunk_size, backend, retrieve_from, lmserver_v1_process, autorelease_v1 +): + url = 
None + if backend == "local_cpu_disk_remote": + url = lmserver_v1_process.server_url + device = "cuda" + num_tokens = 2000 + new_num_tokens = 1000 + kv_shape = (32, 2, chunk_size, 8, 128) + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_tokens, device) + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype=dtype + ) + + new_tokens = generate_tokens(new_num_tokens, device) + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype=dtype + ) + + slot_mapping = random.sample( + range(0, num_blocks * block_size), num_tokens + new_num_tokens + ) + slot_mapping = torch.tensor(slot_mapping[:num_tokens], device=device) + + new_slot_mapping = torch.tensor(slot_mapping[-new_num_tokens:], device=device) + + """ initialize the engine """ + cfg = LMCacheEngineConfig.from_legacy( + chunk_size=chunk_size, backend=backend, remote_url=url + ) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + """ test store """ + t1 = time.perf_counter() + engine.store(tokens, kvcaches=kv_cache, slot_mapping=slot_mapping) + t2 = time.perf_counter() + print(f"store {len(tokens)} takes {t2 - t1}") + """ Compute expected length """ + expected_chunk_cnt = num_tokens // chunk_size + expected_length = expected_chunk_cnt * chunk_size + """ Store is async. Need to wait for the store to finish """ + timeout = 1 + start_time = time.time() + while engine.lookup(tokens) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ Wait until disk save is finished """ + if retrieve_from in ["local_disk", "remote"]: + engine.storage_manager.clear(locations=["LocalCPUBackend"]) + timeout = 30 + start_time = time.time() + while engine.lookup(tokens, ["LocalDiskBackend"]) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ Wait until remote save is finished """ + if retrieve_from == "remote": + engine.storage_manager.clear(locations=["LocalCPUBackend"]) + # FIXME: change this `clear` + engine.storage_manager.storage_backends["LocalDiskBackend"].dict.clear() + timeout = 30 + start_time = time.time() + while engine.lookup(tokens, ["RemoteBackend"]) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ test retrieve """ + t4 = time.perf_counter() + ret_mask = engine.retrieve( + torch.cat([tokens, new_tokens]), + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + + length = torch.sum(ret_mask) + t5 = time.perf_counter() + print(f"retrieve {length} takes {t5 - t4}") + + assert length == expected_length + check_paged_kv_cache_equal( + retrieved_cache, + kv_cache, + torch.cat([slot_mapping, new_slot_mapping])[:expected_length], + ) + + """ Wait until disk save is finished before deleting the directory""" + if backend in ["local_cpu_disk"]: + engine.storage_manager.clear(locations=["LocalCPUBackend"]) + timeout = 30 + start_time = time.time() + while engine.lookup(tokens) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + + if backend in ["local_cpu_disk"]: + 
subprocess.run(shlex.split("rm -rf local/disk_test/local_disk/")) + + +@pytest.mark.parametrize( + "backend", + [ + "local_cpu_disk", + ], +) +@pytest.mark.parametrize( + "prefetch_from", + [ + "local_disk", + ], +) +def test_paged_prefetch_retrieve(backend, prefetch_from, autorelease_v1): + device = "cuda" + num_tokens = 2000 + new_num_tokens = 1000 + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + + chunk_size = 256 + fmt = "vllm" + kv_shape = (32, 2, chunk_size, 8, 128) + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_tokens, device) + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype=dtype + ) + new_tokens = generate_tokens(new_num_tokens, device) + retrieved_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype=dtype + ) + + slot_mapping = random.sample( + range(0, num_blocks * block_size), num_tokens + new_num_tokens + ) + slot_mapping = torch.tensor(slot_mapping[:num_tokens], device=device) + + new_slot_mapping = torch.tensor(slot_mapping[-new_num_tokens:], device=device) + + """ initialize the engine """ + cfg = LMCacheEngineConfig.from_legacy(chunk_size=chunk_size, backend=backend) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + """ test store """ + t1 = time.perf_counter() + engine.store(tokens, kvcaches=kv_cache, slot_mapping=slot_mapping) + t2 = time.perf_counter() + print(f"store {len(tokens)} takes {t2 - t1}") + """ Compute expected length """ + expected_chunk_cnt = num_tokens // chunk_size + expected_length = expected_chunk_cnt * chunk_size + """ Wait for cpu store to finish """ + timeout = 1 + start_time = time.time() + while engine.lookup(tokens) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """ Delete cpu cache and wait until disk save finishes.""" + if prefetch_from == "local_disk": + engine.storage_manager.clear(locations=["LocalCPUBackend"]) + timeout = 30 + start_time = time.time() + while engine.lookup(tokens) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.1) + """ Wait until disk load (prefetch) finishes and delete disk cache""" + engine.prefetch(torch.cat([tokens, new_tokens])) + + if prefetch_from == "local_disk": + timeout = 60 + start_time = time.time() + while ( + engine.lookup(torch.cat([tokens, new_tokens]), ["LocalCPUBackend"]) + < expected_length + ): + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + engine.storage_manager.storage_backends["LocalDiskBackend"].dict.clear() + """ test retrieve """ + t4 = time.perf_counter() + ret_mask = engine.retrieve( + torch.cat([tokens, new_tokens]), + kvcaches=retrieved_cache, + slot_mapping=torch.cat([slot_mapping, new_slot_mapping]), + ) + + length = torch.sum(ret_mask) + t5 = time.perf_counter() + print(f"retrieve {length} takes {t5 - t4}") + + assert length == expected_length + check_paged_kv_cache_equal( + retrieved_cache, + kv_cache, + torch.cat([slot_mapping, new_slot_mapping])[:expected_length], + ) + + if backend in ["local_cpu_disk"]: + subprocess.run(shlex.split("rm -rf local/disk_test/local_disk/")) + + +@pytest.mark.parametrize("fmt", ["vllm"]) +@pytest.mark.parametrize("chunk_size", [128]) +@pytest.mark.parametrize( + "backend", + [ 
+ "cpu", + "local_disk", + "remote", + "local_disk_remote", + "local_cpu_disk_remote", + ], +) +@pytest.mark.parametrize("lmserver_v1_process", ["cpu"], indirect=True) +def test_paged_mem_leak(fmt, chunk_size, backend, lmserver_v1_process, autorelease_v1): + url = None + if "remote" in backend: + url = lmserver_v1_process.server_url + + device = "cuda" + num_tokens = 2000 + kv_shape = (32, 2, chunk_size, 8, 128) + num_blocks = 1000 + block_size = 16 + dtype = torch.bfloat16 + connector = create_gpu_connector(1024, 32) + + tokens = generate_tokens(num_tokens, device) + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + slot_mapping = random.sample(range(0, num_blocks * block_size), num_tokens) + slot_mapping = torch.tensor(slot_mapping, device=device) + """ initialize the engine """ + cfg = LMCacheEngineConfig.from_legacy( + chunk_size=chunk_size, backend=backend, remote_url=url + ) + + engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create( + "test", cfg, dumb_metadata(fmt, kv_shape), connector + ) + ) + + engine.store(tokens, kvcaches=kv_cache, slot_mapping=slot_mapping) + + expected_length = 2000 + timeout = 30 + """Wait until cpu store finishes""" + if "cpu" in backend: + start_time = time.time() + while engine.lookup(tokens, ["LocalCPUBackend"]) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + """Wait until disk store finishes""" + if "disk" in backend: + start_time = time.time() + while engine.lookup(tokens, ["LocalDiskBackend"]) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + + if "remote" in backend: + start_time = time.time() + while engine.lookup(tokens, ["RemoteBackend"]) < expected_length: + if time.time() - start_time > timeout: + raise TimeoutError(f"Operation timed out after {timeout} seconds.") + time.sleep(0.01) + tensor_memory_allocator = ( + engine.storage_manager.allocator_backend.memory_allocator.pin_allocator + ) + if "cpu" not in backend: + assert tensor_memory_allocator.total_allocated_size == 0 + else: + assert tensor_memory_allocator.total_allocated_size > 0 + + if "disk" in backend: + subprocess.run(shlex.split("rm -rf local/disk_test/local_disk/")) + + +def test_builder(autorelease_v1): + instance_id = "test" + cfg = LMCacheEngineConfig.from_legacy(chunk_size=256) + cfg2 = LMCacheEngineConfig.from_legacy(chunk_size=512) + connector = None + should_be_none = LMCacheEngineBuilder.get(instance_id) + assert should_be_none is None + + _engine = autorelease_v1( + LMCacheEngineBuilder.get_or_create(instance_id, cfg, dumb_metadata(), connector) + ) + _engine2 = autorelease_v1(LMCacheEngineBuilder.get(instance_id)) # noqa + + with pytest.raises(ValueError): + LMCacheEngineBuilder.get_or_create( + instance_id, cfg2, dumb_metadata(), connector + ) diff --git a/tests/v1/test_connector.py b/tests/v1/test_connector.py new file mode 100644 index 0000000..6e6bc0f --- /dev/null +++ b/tests/v1/test_connector.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from pathlib import Path +import asyncio +import tempfile + +# Third Party +from utils import ( + check_mem_obj_equal, + close_asyncio_loop, + dumb_cache_engine_key, + init_asyncio_loop, +) +import pytest +import torch + +# First Party +from lmcache_ascend.v1.memory_management import ( + AscendPinMemoryAllocator as PinMemoryAllocator, +) +from 
lmcache.v1.storage_backend.connector import CreateConnector + + +@pytest.mark.parametrize("lmserver_v1_process", ["cpu"], indirect=True) +@pytest.mark.parametrize( + "url", + [ + "lm://localhost:65000", + ], +) +def test_lm_connector(url, autorelease_v1, lmserver_v1_process): + if url.startswith("lm"): + url = lmserver_v1_process.server_url + + async_loop, async_thread = init_asyncio_loop() + memory_allocator = PinMemoryAllocator(1024 * 1024 * 1024) + connector = autorelease_v1(CreateConnector(url, async_loop, memory_allocator)) + + random_key = dumb_cache_engine_key() + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert not future.result() + + num_tokens = 1000 + mem_obj_shape = [2, 32, num_tokens, 1024] + dtype = torch.bfloat16 + memory_obj = memory_allocator.allocate(mem_obj_shape, dtype) + memory_obj.ref_count_up() + + torch.manual_seed(42) + test_tensor = torch.randint(0, 100, memory_obj.raw_data.shape, dtype=torch.int64) + memory_obj.raw_data.copy_(test_tensor.to(torch.float32).to(dtype)) + + future = asyncio.run_coroutine_threadsafe( + connector.put(random_key, memory_obj), async_loop + ) + future.result() + + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert future.result() + assert memory_obj.get_ref_count() == 1 + + future = asyncio.run_coroutine_threadsafe(connector.get(random_key), async_loop) + retrieved_memory_obj = future.result() + + check_mem_obj_equal( + [retrieved_memory_obj], + [memory_obj], + ) + + close_asyncio_loop(async_loop, async_thread) + + memory_allocator.close() + + +@pytest.mark.parametrize("lmserver_v1_process", ["cpu"], indirect=True) +def test_fs_connector(lmserver_v1_process, autorelease_v1): + """Test filesystem connector: exists, put, get, list, and file store.""" + + with tempfile.TemporaryDirectory() as temp_dir: + # Setup + url = f"fs://host:0/{temp_dir}/" + async_loop, async_thread = init_asyncio_loop() + memory_allocator = PinMemoryAllocator(1024 * 1024 * 1024) + connector = autorelease_v1(CreateConnector(url, async_loop, memory_allocator)) + random_key = dumb_cache_engine_key() + + # Test 1: Verify key doesn't exist initially + future = asyncio.run_coroutine_threadsafe( + connector.exists(random_key), async_loop + ) + assert not future.result() + + # Test 2: Create and store test data + dtype = torch.bfloat16 + memory_obj = memory_allocator.allocate([2, 32, 1000, 1024], dtype) + memory_obj.ref_count_up() + # Fill with deterministic test data + torch.manual_seed(42) + test_tensor = torch.randint( + 0, 100, memory_obj.raw_data.shape, dtype=torch.int64 + ) + memory_obj.raw_data.copy_(test_tensor.to(torch.float32).to(dtype)) + + future = asyncio.run_coroutine_threadsafe( + connector.put(random_key, memory_obj), async_loop + ) + future.result() + + # Test 3: Verify key exists after putting data + future = asyncio.run_coroutine_threadsafe( + connector.exists(random_key), async_loop + ) + assert future.result() + assert memory_obj.get_ref_count() == 1 + + # Test 4: Retrieve and verify data + future = asyncio.run_coroutine_threadsafe(connector.get(random_key), async_loop) + check_mem_obj_equal([future.result()], [memory_obj]) + + # Test 5: List the keys + future = asyncio.run_coroutine_threadsafe(connector.list(), async_loop) + assert future.result() == [random_key.to_string()] + + # Test 6: Verify file existence and format + files = list(Path(temp_dir).glob("*.data")) + assert len(files) == 1 + assert files[0].name == f"{random_key.to_string()}.data" + + 
close_asyncio_loop(async_loop, async_thread) + + memory_allocator.close() + + +@pytest.mark.parametrize( + "url", + [ + "redis://localhost:6379", + "redis://user:password@localhost:6379/0", + "redis://:password@localhost:6379/1", + "rediss://user:password@localhost:6380?ssl_cert_reqs=CERT_REQUIRED", + "unix:///tmp/redis.sock", + ], +) +def test_redis_connector(url, autorelease_v1): + """Test Redis connector: exists, put, get operations. + + This test uses the MockRedis from conftest.py to simulate + Redis behavior without requiring an actual Redis server. + """ + + async_loop, async_thread = init_asyncio_loop() + memory_allocator = PinMemoryAllocator(1024 * 1024 * 1024) + connector = autorelease_v1(CreateConnector(url, async_loop, memory_allocator)) + + random_key = dumb_cache_engine_key() + + # Test 1: Verify key doesn't exist initially + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert not future.result() + + # Test 2: Create and store test data + num_tokens = 1000 + mem_obj_shape = [2, 32, num_tokens, 1024] + dtype = torch.bfloat16 + memory_obj = memory_allocator.allocate(mem_obj_shape, dtype) + memory_obj.ref_count_up() + + torch.manual_seed(42) + test_tensor = torch.randint(0, 100, memory_obj.raw_data.shape, dtype=torch.int64) + memory_obj.raw_data.copy_(test_tensor.to(torch.float32).to(dtype)) + + # Test 3: Put data + future = asyncio.run_coroutine_threadsafe( + connector.put(random_key, memory_obj), async_loop + ) + future.result() + + # Test 4: Verify key exists after putting data + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert future.result() + assert memory_obj.get_ref_count() == 1 + + # Test 5: Retrieve and verify data + future = asyncio.run_coroutine_threadsafe(connector.get(random_key), async_loop) + retrieved_memory_obj = future.result() + + check_mem_obj_equal( + [retrieved_memory_obj], + [memory_obj], + ) + + close_asyncio_loop(async_loop, async_thread) + + memory_allocator.close() + + +@pytest.mark.parametrize( + "url", + [ + "redis-sentinel://localhost:26379,localhost:26380,localhost:26381", + "redis-sentinel://user:password@localhost:26379,localhost:26380", + "redis-sentinel://localhost:26379", + ], +) +def test_redis_sentinel_connector(url, autorelease_v1): + """Test Redis Sentinel connector: exists, put, get operations. + + This test uses the MockRedisSentinel from conftest.py to simulate + Redis Sentinel behavior without requiring an actual Redis Sentinel setup. 
+ """ + # Standard + import os + + # Set required environment variables for Redis Sentinel + os.environ["REDIS_SERVICE_NAME"] = "mymaster" + os.environ["REDIS_TIMEOUT"] = "5" + + async_loop, async_thread = init_asyncio_loop() + memory_allocator = PinMemoryAllocator(1024 * 1024 * 1024) + connector = autorelease_v1(CreateConnector(url, async_loop, memory_allocator)) + + random_key = dumb_cache_engine_key() + + # Test 1: Verify key doesn't exist initially + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert not future.result() + + # Test 2: Create and store test data + num_tokens = 1000 + mem_obj_shape = [2, 32, num_tokens, 1024] + dtype = torch.bfloat16 + memory_obj = memory_allocator.allocate(mem_obj_shape, dtype) + memory_obj.ref_count_up() + + # Fill with deterministic test data for Redis Sentinel test + torch.manual_seed(123) + test_tensor = torch.randint(0, 100, memory_obj.raw_data.shape, dtype=torch.int64) + memory_obj.raw_data.copy_(test_tensor.to(torch.float32).to(dtype)) + + # Test 3: Put data + future = asyncio.run_coroutine_threadsafe( + connector.put(random_key, memory_obj), async_loop + ) + future.result() + + # Test 4: Verify key exists after putting data + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert future.result() + + # Test 5: Retrieve and verify data + future = asyncio.run_coroutine_threadsafe(connector.get(random_key), async_loop) + future.result() + + close_asyncio_loop(async_loop, async_thread) + + memory_allocator.close() diff --git a/tests/v1/test_mem_kernels.py b/tests/v1/test_mem_kernels.py new file mode 100755 index 0000000..d08ba60 --- /dev/null +++ b/tests/v1/test_mem_kernels.py @@ -0,0 +1,476 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import List +import random + +# Third Party +from utils import ( + check_mem_obj_equal, + check_paged_kv_cache_equal, + generate_kv_cache_paged, + generate_kv_cache_paged_list_tensors, + generate_mla_kv_cache_paged_list_tensors, +) +import pytest +import torch + +# First Party +from lmcache_ascend.v1.memory_management import ( + AscendPinMemoryAllocator as PinMemoryAllocator, +) +import lmcache.c_ops as lmc_ops + + +def _tuple_kv_to_blob( + kv_tensors, +) -> torch.Tensor: + k_temp = [] + v_temp = [] + for kv_layer in kv_tensors: + k_temp.append(kv_layer[0]) + v_temp.append(kv_layer[1]) + k_tensor_blob = torch.stack(k_temp) + v_tensor_blob = torch.stack(v_temp) + + # kv_tensors: [num_layer, 2, num_tok, num_kv_head, head_size] + kv_tensors_flatten = torch.stack((k_tensor_blob, v_tensor_blob)) + kv_tensors_flatten = kv_tensors_flatten.permute([1, 0, 2, 3, 4]) + + return kv_tensors_flatten + + +def _slice_kv_at( + start_idx: int, + kv_tensors: torch.Tensor, + chunk_size: int, +) -> List[torch.Tensor]: + return [ + x.contiguous() + for x in list( + torch.split( + kv_tensors[:, :, start_idx:, ...], + chunk_size, + dim=2, + ) + ) + ] + + +@pytest.mark.parametrize("num_tokens", [256, 500, 1024, 8000]) +def test_extract_and_load_back(num_tokens): + device = "cuda" + + num_blocks = 1000 + block_size = 16 + num_heads = 8 + head_size = 128 + dtype = torch.bfloat16 + kv_cache = generate_kv_cache_paged(num_blocks, device, block_size, dtype) + + slot_mapping = random.sample(range(0, num_blocks * block_size), num_tokens) + slot_mapping = torch.tensor(slot_mapping, device=device) + + pinned_cpu_size = 4 * 1024 * 1024 * 1024 # 4GB + mem_allocator = PinMemoryAllocator(pinned_cpu_size) + + # Old extract + kv_tuple_list = [] + 
memory_obj_old_list = [] + chunk_size = 256 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for layer_id in range(32): + key_cache = kv_cache[layer_id][0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[layer_id][1].reshape(-1, num_heads, head_size) + kv_tuple_list.append((key_cache[slot_mapping], value_cache[slot_mapping])) + kv_blob = _tuple_kv_to_blob(kv_tuple_list) + kv_chunked = _slice_kv_at(0, kv_blob, chunk_size) + for chunk_id, chunk in enumerate(kv_chunked): + mem_obj_shape = [2, 32, chunk.shape[2], num_heads * head_size] + + memory_obj_old = mem_allocator.allocate(mem_obj_shape, dtype) + chunk = chunk.contiguous() + for layer_id in range(32): + memory_obj_old.tensor[0, layer_id].copy_( + chunk[layer_id, 0].reshape(-1, 1024) + ) + memory_obj_old.tensor[1, layer_id].copy_( + chunk[layer_id, 1].reshape(-1, 1024) + ) + memory_obj_old_list.append(memory_obj_old) + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print("Old extract time: ", elapsed_time_ms / 1000) + + # New extract (zero-copy kernels) + memory_obj_new_list = [] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + slot_mapping_chunked = torch.split(slot_mapping, chunk_size) + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + mem_obj_shape = [2, 32, len(slot_mapping_temp), num_heads * head_size] + + memory_obj_new = mem_allocator.allocate(mem_obj_shape, dtype) + for layer_id in range(32): + lmc_ops.load_and_reshape_flash( + memory_obj_new.tensor, + kv_cache[layer_id][0], + kv_cache[layer_id][1], + slot_mapping_temp, + layer_id, + ) + memory_obj_new_list.append(memory_obj_new) + end_event.record() + # wait for all the operations to finish + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print("New extract time: ", elapsed_time_ms / 1000) + check_mem_obj_equal( + memory_obj_old_list, + memory_obj_new_list, + ) + + # Generate new paged kv_cache + kv_cache_new = generate_kv_cache_paged(num_blocks, device, block_size, dtype) + + # New load back (zero-copy kernels) + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + memory_obj_new = memory_obj_new_list[chunk_id] + for layer_id in range(32): + lmc_ops.reshape_and_cache_back_flash( + memory_obj_new.tensor, + kv_cache_new[layer_id][0], + kv_cache_new[layer_id][1], + slot_mapping_temp, + layer_id, + ) + check_paged_kv_cache_equal( + kv_cache, + kv_cache_new, + slot_mapping, + ) + + mem_allocator.close() + + +@pytest.mark.parametrize("num_tokens", [256, 500, 1024, 8000]) +def test_multi_layer_kernel(num_tokens): + device = "cuda" + + num_blocks = 1000 + block_size = 16 + num_heads = 8 + head_size = 128 + chunk_size = 256 + dtype = torch.bfloat16 + kv_cache = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + page_buffer_size = num_blocks * block_size + + slot_mapping = random.sample(range(0, num_blocks * block_size), num_tokens) + slot_mapping = torch.tensor(slot_mapping, device=device) + + pinned_cpu_size = 4 * 1024 * 1024 * 1024 # 4GB + mem_allocator = PinMemoryAllocator(pinned_cpu_size) + + # lmc_ops.multi_layer_kv_transfer(memory_obj_new.tensor, + # kv_cache_pointers, # TODO: initialize this + # slot_mapping_temp, + # kv_cache[0].device, + # len(slot_mapping_temp), True) + + # layer by layer extract + memory_obj_old_list = [] + start_event = 
torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + slot_mapping_chunked = torch.split(slot_mapping, chunk_size) + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + mem_obj_shape = [2, 32, len(slot_mapping_temp), num_heads * head_size] + + memory_obj_old = mem_allocator.allocate(mem_obj_shape, dtype) + for layer_id in range(32): + lmc_ops.load_and_reshape_flash( + memory_obj_old.tensor, + kv_cache[layer_id][0], + kv_cache[layer_id][1], + slot_mapping_temp, + layer_id, + ) + memory_obj_old_list.append(memory_obj_old) + end_event.record() + # wait for all the operations to finish + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print("Old extract time: ", elapsed_time_ms / 1000) + + # New extract with multi layer kernel + kv_cache_pointers = torch.empty( + 32, dtype=torch.int64, device="cpu", pin_memory=True + ) + for i in range(32): + kv_cache_pointers[i] = kv_cache[i].data_ptr() + + # NOTE (Gingfung): Ascend kernels require kv_cache_pointers to be on dev + kv_cache_pointers = kv_cache_pointers.cuda() + + memory_obj_new_list = [] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + slot_mapping_chunked = torch.split(slot_mapping, chunk_size) + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + mem_obj_shape = [2, 32, len(slot_mapping_temp), num_heads * head_size] + + memory_obj_new = mem_allocator.allocate(mem_obj_shape, dtype) + lmc_ops.multi_layer_kv_transfer( + memory_obj_new.tensor, + kv_cache_pointers, + slot_mapping_temp, + kv_cache[0].device, + page_buffer_size, + True, + False, + ) + memory_obj_new_list.append(memory_obj_new) + + end_event.record() + # wait for all the operations to finish + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print("New extract time: ", elapsed_time_ms / 1000) + + check_mem_obj_equal( + memory_obj_old_list, + memory_obj_new_list, + ) + + # Generate new paged kv_cache + kv_cache_new = generate_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype + ) + + kv_cache_pointers_new = torch.empty( + 32, dtype=torch.int64, device="cpu", pin_memory=True + ) + for i in range(32): + kv_cache_pointers_new[i] = kv_cache_new[i].data_ptr() + + # NOTE (Gingfung): Ascend kernels require kv_cache_pointers to be on dev + kv_cache_pointers_new = kv_cache_pointers_new.cuda() + + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + memory_obj_new = memory_obj_new_list[chunk_id] + lmc_ops.multi_layer_kv_transfer( + memory_obj_new.tensor, + kv_cache_pointers_new, + slot_mapping_temp, + kv_cache_new[0].device, + page_buffer_size, + False, + False, + ) + + check_paged_kv_cache_equal( + kv_cache, + kv_cache_new, + slot_mapping, + ) + + mem_allocator.close() + + +@pytest.mark.parametrize("num_tokens", [256, 500, 1024, 8000]) +def test_multi_layer_kernel_use_mla(num_tokens): + device = "cuda" + + num_blocks = 1000 + block_size = 64 + head_size = 576 + chunk_size = 256 + dtype = torch.bfloat16 + num_layers = 32 + kv_cache = generate_mla_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype, num_layers + ) + + slot_mapping = random.sample(range(0, num_blocks * block_size), num_tokens) + slot_mapping = torch.tensor(slot_mapping, device=device) + + pinned_cpu_size = 4 * 1024 * 1024 * 1024 # 4GB + mem_allocator = PinMemoryAllocator(pinned_cpu_size) + + # layer by layer extract + memory_obj_old_list 
= [] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + slot_mapping_chunked = torch.split(slot_mapping, chunk_size) + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + mem_obj_shape = [1, num_layers, len(slot_mapping_temp), head_size] + memory_obj_old = mem_allocator.allocate(mem_obj_shape, dtype) + + for layer_id in range(num_layers): + for token_idx, slot_idx in enumerate(slot_mapping_temp): + slot_idx = slot_idx.item() + + block_idx = slot_idx // block_size + block_offset = slot_idx % block_size + + memory_obj_old.tensor[0][layer_id][token_idx] = kv_cache[layer_id][ + block_idx + ][block_offset] + + memory_obj_old_list.append(memory_obj_old) + end_event.record() + # wait for all the operations to finish + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print("Old extract time: ", elapsed_time_ms / 1000) + + # New extract with multi layer kernel + kv_cache_pointers = torch.empty( + num_layers, dtype=torch.int64, device="cpu", pin_memory=True + ) + for i in range(num_layers): + kv_cache_pointers[i] = kv_cache[i].data_ptr() + + # NOTE (Gingfung): Ascend kernels require kv_cache_pointers to be on dev + kv_cache_pointers = kv_cache_pointers.cuda() + + memory_obj_new_list = [] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + slot_mapping_chunked = torch.split(slot_mapping, chunk_size) + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + mem_obj_shape = [1, num_layers, len(slot_mapping_temp), head_size] + + memory_obj_new = mem_allocator.allocate(mem_obj_shape, dtype) + lmc_ops.multi_layer_kv_transfer( + memory_obj_new.tensor, + kv_cache_pointers, + slot_mapping_temp, + kv_cache[0].device, + 0, + True, + True, + ) + memory_obj_new_list.append(memory_obj_new) + + end_event.record() + # wait for all the operations to finish + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print("New extract time: ", elapsed_time_ms / 1000) + + for left_mem_obj, right_mem_obj in zip( + memory_obj_old_list, memory_obj_new_list, strict=False + ): + left_kv, right_kv = left_mem_obj.tensor[0], right_mem_obj.tensor[0] + right_kv = right_kv.to(left_kv.device) + + assert len(left_kv.shape) == 3 + assert len(right_kv.shape) == 3 + + assert (left_kv[:, :, :] == right_kv[:, :, :]).all() + + # Generate new paged kv_cache + kv_cache_new = generate_mla_kv_cache_paged_list_tensors( + num_blocks, device, block_size, dtype, num_layers + ) + + kv_cache_pointers_new = torch.empty( + num_layers, dtype=torch.int64, device="cpu", pin_memory=True + ) + for i in range(num_layers): + kv_cache_pointers_new[i] = kv_cache_new[i].data_ptr() + + # NOTE (Gingfung): Ascend kernels require kv_cache_pointers to be on dev + kv_cache_pointers_new = kv_cache_pointers_new.cuda() + + for chunk_id, slot_mapping_temp in enumerate(slot_mapping_chunked): + memory_obj_new = memory_obj_new_list[chunk_id] + lmc_ops.multi_layer_kv_transfer( + memory_obj_new.tensor, + kv_cache_pointers_new, + slot_mapping_temp, + kv_cache_new[0].device, + 0, + False, + True, + ) + + for left_kv, right_kv in zip(kv_cache, kv_cache_new, strict=False): + assert len(left_kv.shape) == 3 + assert len(right_kv.shape) == 3 + + left_reshaped = left_kv.reshape( + left_kv.shape[0] * left_kv.shape[1], left_kv.shape[2] + ) + right_reshaped = right_kv.reshape( + right_kv.shape[0] * right_kv.shape[1], right_kv.shape[2] + ) + + 
assert (left_reshaped[slot_mapping, :] == right_reshaped[slot_mapping, :]).all()
+
+    mem_allocator.close()
+
+
+@pytest.mark.parametrize("num_tokens", [256, 500, 1024, 8000])
+@pytest.mark.parametrize("token_major", [True, False])
+def test_single_layer_kernel(num_tokens, token_major):
+    device = "cuda"
+
+    num_layers = 32
+    num_blocks = 1000
+    block_size = 16
+    num_heads = 8
+    head_size = 128
+    hidden_dim_size = num_heads * head_size
+    dtype = torch.bfloat16
+    kv_cache = generate_kv_cache_paged_list_tensors(
+        num_blocks, device, block_size, dtype
+    )
+    kv_cache_new = generate_kv_cache_paged_list_tensors(
+        num_blocks, device, block_size, dtype
+    )
+    slot_mapping = random.sample(range(0, num_blocks * block_size), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping, device=device)
+
+    if token_major:
+        tmp_gpu_buffer = torch.empty(
+            (num_tokens, 2, hidden_dim_size), dtype=dtype, device=device
+        )
+    else:
+        tmp_gpu_buffer = torch.empty(
+            (2, num_tokens, hidden_dim_size), dtype=dtype, device=device
+        )
+
+    for layer_id in range(num_layers):
+        lmc_ops.single_layer_kv_transfer(
+            tmp_gpu_buffer,
+            kv_cache[layer_id][0],
+            kv_cache[layer_id][1],
+            slot_mapping,
+            True,
+            token_major,
+        )
+        lmc_ops.single_layer_kv_transfer(
+            tmp_gpu_buffer,
+            kv_cache_new[layer_id][0],
+            kv_cache_new[layer_id][1],
+            slot_mapping,
+            False,
+            token_major,
+        )
+
+    check_paged_kv_cache_equal(
+        kv_cache,
+        kv_cache_new,
+        slot_mapping,
+    )
diff --git a/tests/v1/test_memory_management.py b/tests/v1/test_memory_management.py
new file mode 100644
index 0000000..12129b8
--- /dev/null
+++ b/tests/v1/test_memory_management.py
@@ -0,0 +1,276 @@
+# SPDX-License-Identifier: Apache-2.0
+# From LMCache
+# Third Party
+import pytest
+import torch
+
+# First Party
+from lmcache.v1.memory_management import (
+    BytesBufferMemoryObj,
+    GPUMemoryAllocator,
+    HostMemoryAllocator,
+    MemoryFormat,
+    PagedTensorMemoryAllocator,
+    TensorMemoryAllocator,
+)
+
+from lmcache_ascend.v1.memory_management import (
+    AscendPinMemoryAllocator as PinMemoryAllocator,
+    AscendMixedMemoryAllocator as MixedMemoryAllocator,
+)
+
+
+def check_allocator(allocator, max_size):
+    # 512 * 512 * 4 = 1MB
+    data1 = allocator.allocate([512, 512], torch.float)
+    assert data1 is not None
+    assert data1.tensor.dtype == torch.float
+    assert data1.tensor.shape == (512, 512)
+
+    # 1024 * 1024 * 2 = 2MB
+    data2 = allocator.allocate([1024, 1024], dtype=torch.bfloat16)
+    assert data2 is not None
+    assert data2.tensor.dtype == torch.bfloat16
+    assert data2.tensor.shape == (1024, 1024)
+
+    # 2048 * 2048 * 1 = 4MB
+    data3 = allocator.allocate([2048, 2048], dtype=torch.int8)
+    assert data3 is not None
+    assert data3.tensor.dtype == torch.int8
+    assert data3.tensor.shape == (2048, 2048)
+
+    allocator.free(data2)
+    assert data2.tensor is None
+    assert allocator.memcheck()
+
+    allocator.free(data1)
+    assert data1.tensor is None
+    assert allocator.memcheck()
+
+    allocator.free(data2)  # This should not crash
+
+    data4 = allocator.allocate([3, 5, 7], dtype=torch.half)
+    assert data4 is not None
+    assert data4.tensor.dtype == torch.half
+    assert data4.tensor.shape == (3, 5, 7)
+
+    data_fail = allocator.allocate([max_size], dtype=torch.float)  # This should fail
+    assert data_fail is None
+
+    assert allocator.memcheck()
+
+    allocator.free(data1)
+    allocator.free(data2)
+    allocator.free(data3)
+    allocator.free(data4)
+
+    assert allocator.memcheck()
+
+    allocator.close()
+
+
+def check_paged_allocator(allocator, shape, dtype, fmt, max_num_pages):
+    # Allocate one page
+    data1 = allocator.allocate(shape, dtype, fmt)
+    assert data1 is not None
+    assert data1.tensor.dtype == dtype
+    assert data1.tensor.shape == shape
+
+    # Allocate another 2 pages
+    data2 = allocator.batched_allocate(shape, dtype, 2, fmt)
+
+    for data in data2:
+        assert data is not None
+        assert data.tensor.dtype == dtype
+        assert data.tensor.shape == shape
+
+    # Allocate a smaller page
+    smaller_shape = torch.Size([2, 32, 8, 1024])
+    data3 = allocator.allocate(smaller_shape, dtype, fmt)
+    assert data3 is not None
+    assert data3.tensor.dtype == dtype
+    assert data3.tensor.shape == smaller_shape
+
+    allocator.free(data3)
+    assert allocator.memcheck()
+
+    allocator.batched_free(data2)
+    assert allocator.memcheck()
+
+    allocator.free(data1)
+    assert allocator.memcheck()
+
+    data_fail = allocator.batched_allocate(
+        shape, dtype, max_num_pages + 1, fmt
+    )  # This should fail
+    assert data_fail is None
+
+    assert allocator.memcheck()
+
+    allocator.close()
+
+
+@pytest.mark.parametrize(
+    "use_paging",
+    [True, False],
+)
+def test_tensor_allocator(use_paging):
+    total_size = 1024 * 1024 * 128  # 128MB
+    tensor_buffer = torch.zeros(total_size, dtype=torch.uint8, device="cpu")
+    if use_paging:
+        shape = torch.Size([2, 32, 16, 1024])  # 64 pages
+        dtype = torch.bfloat16
+        fmt = MemoryFormat.KV_2LTD
+        num_pages = 64
+        allocator = PagedTensorMemoryAllocator(tensor_buffer, shape, dtype, fmt)
+        check_paged_allocator(allocator, shape, dtype, fmt, num_pages)
+    else:
+        allocator = TensorMemoryAllocator(tensor_buffer)
+        check_allocator(allocator, total_size)
+
+    allocator.close()
+
+
+@pytest.mark.parametrize(
+    "alloc_cls",
+    [
+        HostMemoryAllocator,
+        PinMemoryAllocator,
+        GPUMemoryAllocator,
+        MixedMemoryAllocator,
+    ],
+)
+@pytest.mark.parametrize(
+    "use_paging",
+    [
+        False,
+        True,
+    ],
+)
+def test_device_allocators(alloc_cls, use_paging):
+    total_size = 1024 * 1024 * 128  # 128MB
+
+    shape = torch.Size([2, 32, 16, 1024])  # 64 pages
+    dtype = torch.bfloat16
+    fmt = MemoryFormat.KV_2LTD
+
+    allocator = alloc_cls(
+        total_size, use_paging=use_paging, shape=shape, dtype=dtype, fmt=fmt
+    )
+
+    if use_paging:
+        num_pages = 64
+        check_paged_allocator(allocator, shape, dtype, fmt, num_pages)
+    else:
+        check_allocator(allocator, total_size)
+
+    allocator.close()
+
+
+@pytest.mark.parametrize(
+    "alloc_cls",
+    [
+        HostMemoryAllocator,
+        PinMemoryAllocator,
+        GPUMemoryAllocator,
+        MixedMemoryAllocator,
+    ],
+)
+def test_inplace_modification(alloc_cls):
+    total_size = 1024
+    allocator = alloc_cls(total_size)
+
+    data = allocator.allocate([10], torch.float)
+    assert data is not None
+    assert data.tensor.dtype == torch.float
+    assert data.tensor.shape == (10,)
+
+    data.tensor.fill_(1.0)
+    assert torch.all(data.tensor == 1.0)
+
+    data.tensor[1] = 2.0
+    assert data.tensor[1] == 2.0
+
+    allocator.close()
+
+
+@pytest.mark.parametrize(
+    "alloc_cls",
+    [
+        HostMemoryAllocator,
+        PinMemoryAllocator,
+        GPUMemoryAllocator,
+        MixedMemoryAllocator,
+    ],
+)
+def test_boundary_alloc(alloc_cls):
+    total_size = 1 << 25
+    allocator = alloc_cls(total_size)
+    data1 = allocator.allocate([512, 10], torch.float)
+    allocator.allocate([512, 10], torch.float)
+    allocator.free(data1)
+
+    # `FreeBlock` with size 0 shouldn't exist in the allocator
+    allocator.allocate([512, 10], torch.float)
+
+    if isinstance(allocator, MixedMemoryAllocator):
+        assert len(allocator.pin_allocator.explicit_list) == 1
+    else:
+        assert len(allocator.allocator.explicit_list) == 1
+
+    allocator.close()
+
+
+@pytest.mark.parametrize(
+    "alloc_cls",
+    [
+        HostMemoryAllocator,
+        PinMemoryAllocator,
+        GPUMemoryAllocator,
+        MixedMemoryAllocator,
+    ],
+)
+def test_batched_alloc(alloc_cls):
+    total_size = 32 * 100 * 2 * 1024 * 2
+    batch_size = 32
+    allocator = alloc_cls(total_size)
+    objs = allocator.batched_allocate(
+        [100, 2, 1024], torch.bfloat16, batch_size, MemoryFormat.KV_T2D
+    )
+
+    assert len(objs) == batch_size
+    for obj in objs:
+        assert obj is not None
+        assert obj.tensor is not None
+        assert obj.tensor.dtype == torch.bfloat16
+        assert obj.tensor.shape == (100, 2, 1024)
+    allocator.batched_free(objs)
+
+    if isinstance(allocator, MixedMemoryAllocator):
+        assert len(allocator.pin_allocator.explicit_list) == 1
+    else:
+        assert len(allocator.allocator.explicit_list) == 1
+
+    allocator.close()
+
+
+@pytest.mark.parametrize(
+    "alloc_cls",
+    [
+        MixedMemoryAllocator,
+    ],
+)
+def test_mixed_alloc(alloc_cls):
+    total_size = 1 << 25
+    allocator = alloc_cls(total_size)
+    data1 = allocator.allocate([512, 0], None, MemoryFormat.BINARY_BUFFER)
+    allocator.allocate([512, 10], torch.float)
+    allocator.free(data1)
+
+    assert len(allocator.pin_allocator.explicit_list) == 1
+
+    assert isinstance(data1, BytesBufferMemoryObj)
+
+    assert len(data1.byte_array) == 512
+
+    allocator.close()
diff --git a/tests/v1/utils.py b/tests/v1/utils.py
new file mode 100644
index 0000000..129b56c
--- /dev/null
+++ b/tests/v1/utils.py
@@ -0,0 +1,273 @@
+# SPDX-License-Identifier: Apache-2.0
+# From LMCache
+# Standard
+import asyncio
+import random
+import string
+import threading
+
+# Third Party
+import torch
+
+# First Party
+from lmcache.config import LMCacheEngineMetadata
+from lmcache.utils import CacheEngineKey
+from lmcache_ascend.v1.npu_connector import VLLMPagedMemNPUConnectorV2
+
+
+def dumb_metadata(fmt="vllm", kv_shape=(32, 2, 256, 8, 128)):
+    return LMCacheEngineMetadata("test_model", 3, 123, fmt, torch.bfloat16, kv_shape)
+
+
+def dumb_metadata_with_model_name(
+    model_name: str, fmt="vllm", kv_shape=(32, 2, 256, 8, 128)
+):
+    return LMCacheEngineMetadata(model_name, 3, 123, fmt, torch.bfloat16, kv_shape)
+
+
+def dumb_cache_engine_key():
+    return CacheEngineKey("vllm", "test_model", 3, 123, 1234)
+
+
+def random_string(N):
+    return "".join(random.choices(string.ascii_uppercase + string.digits, k=N))
+
+
+def init_asyncio_loop():
+    async_loop = asyncio.new_event_loop()
+    async_thread = threading.Thread(target=async_loop.run_forever)
+    async_thread.start()
+    return async_loop, async_thread
+
+
+def close_asyncio_loop(async_loop, async_thread):
+    if async_loop.is_running():
+        async_loop.call_soon_threadsafe(async_loop.stop)
+    if async_thread.is_alive():
+        async_thread.join()
+
+
+def generate_kv_cache(num_tokens, fmt, device):
+    ret = []
+    num_layers = 32
+    num_heads = 8
+    head_size = 128
+    shape = (
+        [num_tokens, num_heads, head_size]
+        if fmt == "vllm"
+        else [num_heads, num_tokens, head_size]
+    )
+    dtype = torch.bfloat16 if fmt == "vllm" else torch.float16
+
+    for i in range(num_layers):
+        k = torch.rand(shape, dtype=dtype, device=device)
+        v = torch.rand(shape, dtype=dtype, device=device)
+        ret.append((k, v))
+
+    return tuple(ret)
+
+
+def generate_kv_cache_paged_list_tensors(
+    num_blocks, device, block_size=16, dtype=torch.bfloat16, use_mla=False
+):
+    """
+    Instead of Tuple[Tuple[Tensor, Tensor]], return List[Tensor]
+    where KV are in the same tensor
+    """
+    ret = []
+    num_layers = 32
+    num_heads = 1 if use_mla else 8
+    head_size = 128
+    shape = (
+        [num_blocks, block_size, head_size]
+        if use_mla
+        else [2, num_blocks, block_size, num_heads, head_size]
+    )
+
+    for i in range(num_layers):
+        kv = torch.rand(shape, dtype=dtype, device=device)
+        ret.append(kv)
+
+    return ret
+
+
+def generate_sglang_kv_cache_paged_list_tensors(
+    num_layers,
+    num_blocks,
+    block_size,
+    num_heads,
+    head_size,
+    use_mla=False,
+    device="cuda",
+    dtype=torch.bfloat16,
+):
+    """
+    Return List[Tensor]: one combined KV tensor per layer when use_mla is True,
+    otherwise the per-layer K caches followed by the per-layer V caches
+    """
+    shape = (
+        [num_blocks * block_size, 1, head_size]
+        if use_mla
+        else [num_blocks * block_size, num_heads, head_size]
+    )
+    if use_mla:
+        kv_cache = [
+            torch.rand(shape, dtype=dtype, device=device) for i in range(num_layers)
+        ]
+    else:
+        k_cache = [
+            torch.rand(shape, dtype=dtype, device=device) for i in range(num_layers)
+        ]
+        v_cache = [
+            torch.rand(shape, dtype=dtype, device=device) for i in range(num_layers)
+        ]
+        kv_cache = k_cache + v_cache
+    return kv_cache
+
+
+def generate_mla_kv_cache_paged_list_tensors(
+    num_blocks, device, block_size=64, dtype=torch.bfloat16, num_layers=32
+):
+    """
+    return the paged KV cache for MLA
+    """
+    ret = []
+    head_size = 576
+    shape = [num_blocks, block_size, head_size]
+
+    for i in range(num_layers):
+        kv = torch.rand(shape, dtype=dtype, device=device)
+        ret.append(kv)
+
+    return ret
+
+
+def generate_kv_cache_paged(num_blocks, device, block_size=16, dtype=torch.bfloat16):
+    ret = []
+    num_layers = 32
+    num_heads = 8
+    head_size = 128
+    shape = [num_blocks, block_size, num_heads, head_size]
+
+    for i in range(num_layers):
+        k = torch.rand(shape, dtype=dtype, device=device)
+        v = torch.rand(shape, dtype=dtype, device=device)
+        ret.append((k, v))
+
+    return tuple(ret)
+
+
+def generate_tokens(num_tokens, device, fixed=False):
+    if fixed:
+        return torch.tensor([-1] * num_tokens).to(device)
+    else:
+        # random tokens
+        return torch.randint(0, 10000, size=[num_tokens]).to(device)
+
+
+def concatenate_kv_caches(kv_chunks, fmt):
+    dim = 1 if fmt == "huggingface" else 0
+    ret = []
+    for kv_layer in zip(*kv_chunks, strict=False):
+        klist, vlist = zip(*kv_layer, strict=False)
+        klayer = torch.cat(klist, dim=dim)
+        vlayer = torch.cat(vlist, dim=dim)
+        ret.append((klayer, vlayer))
+    return tuple(ret)
+
+
+def check_mem_obj_equal(left, right):
+    """
+    check whether two memory objects are the same
+    """
+    for left_mem_obj, right_mem_obj in zip(left, right, strict=False):
+        left_kv, right_kv = left_mem_obj.tensor, right_mem_obj.tensor
+        left_k, left_v = left_kv[0], left_kv[1]
+        right_k, right_v = right_kv[0], right_kv[1]
+        right_k = right_k.to(left_k.device)
+        right_v = right_v.to(left_v.device)
+
+        assert len(left_k.shape) == 3
+        assert len(left_v.shape) == 3
+        assert len(right_k.shape) == 3
+        assert len(right_v.shape) == 3
+
+        assert (left_k[:, :, :] == right_k[:, :, :]).all()
+        assert (left_v[:, :, :] == right_v[:, :, :]).all()
+
+
+def check_paged_kv_cache_equal(left, right, slot_mapping, num_heads=8, head_size=128):
+    """
+    check whether two paged kv caches are the same at slot_mapping
+    """
+    token_dim = 0
+    num_tokens = slot_mapping.shape[0]
+    for left_kv, right_kv in zip(left, right, strict=False):
+        left_k = left_kv[0].reshape(-1, num_heads, head_size)
+        left_v = left_kv[1].reshape(-1, num_heads, head_size)
+        right_k = right_kv[0].reshape(-1, num_heads, head_size)
+        right_v = right_kv[1].reshape(-1, num_heads, head_size)

+        assert len(left_k.shape) == 3
+        assert len(left_v.shape) == 3
+        assert len(right_k.shape) == 3
+        assert len(right_v.shape) == 3
+
+        assert left_k.shape[token_dim] >= num_tokens
+        assert left_v.shape[token_dim] >= num_tokens
+        assert right_k.shape[token_dim] >= num_tokens
+        assert right_v.shape[token_dim] >= num_tokens
+
+        assert (left_k[slot_mapping, :, :] == right_k[slot_mapping, :, :]).all()
+        assert (left_v[slot_mapping, :, :] == right_v[slot_mapping, :, :]).all()
+
+
+def check_sglang_paged_kv_cache_equal(
+    left, right, slot_mapping, num_heads=8, head_size=128
+):
+    """
+    check whether two paged kv caches are the same at slot_mapping
+    """
+    token_dim = 0
+    num_tokens = slot_mapping.shape[0]
+    for left_kv, right_kv in zip(left, right, strict=False):
+        _left_kv = left_kv.reshape(-1, num_heads, head_size)
+        _right_kv = right_kv.reshape(-1, num_heads, head_size)
+
+        assert len(_left_kv.shape) == 3
+        assert len(_right_kv.shape) == 3
+
+        assert _left_kv.shape[token_dim] >= num_tokens
+        assert _right_kv.shape[token_dim] >= num_tokens
+
+        assert (_left_kv[slot_mapping, :, :] == _right_kv[slot_mapping, :, :]).all()
+
+
+def check_paged_kv_cache_equal_with_mla(left, right, slot_mapping, head_size=128):
+    """
+    check whether two paged kv caches are the same at slot_mapping when using MLA
+    """
+    token_dim = 0
+    num_tokens = slot_mapping.shape[0]
+    for left_kv, right_kv in zip(left, right, strict=False):
+        new_left_kv = left_kv.reshape(-1, head_size)
+        new_right_kv = right_kv.reshape(-1, head_size)
+
+        assert len(new_left_kv.shape) == 2
+        assert len(new_right_kv.shape) == 2
+
+        assert new_left_kv.shape[token_dim] >= num_tokens
+        assert new_right_kv.shape[token_dim] >= num_tokens
+
+        assert (new_left_kv[slot_mapping, :] == new_right_kv[slot_mapping, :]).all()
+
+
+def check_kv_cache_device(kvs, device):
+    for kv in kvs:
+        k, v = kv
+        assert k.device == torch.device(device)
+        assert v.device == torch.device(device)
+
+
+def create_gpu_connector(hidden_dim, num_layers):
+    return VLLMPagedMemNPUConnectorV2(hidden_dim, num_layers)
diff --git a/third_party/kvcache-ops b/third_party/kvcache-ops
new file mode 160000
index 0000000..bdd216c
--- /dev/null
+++ b/third_party/kvcache-ops
@@ -0,0 +1 @@
+Subproject commit bdd216cb6a2446494afcb26b68f65a2516e5a7ac