diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml new file mode 100644 index 0000000..60df01f --- /dev/null +++ b/.github/workflows/docker-image.yml @@ -0,0 +1,90 @@ +name: DockerHub CI + +on: + release: + types: [published] + # push: + # branches: + # - main +env: + DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Run tests + run: | + if [ -f docker-compose.test.yml ]; then + docker-compose --file docker-compose.test.yml build + docker-compose --file docker-compose.test.yml run sut + else + docker build . --file Dockerfile + fi + + docker_task: + needs: test + name: ${{ matrix.instruct}} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + # for amd64 + - {instruct: "FANCY", platform: "linux/amd64"} + - {instruct: "AVX512", platform: "linux/amd64"} + - {instruct: "AVX2", platform: "linux/amd64"} + - {instruct: "NATIVE", platform: "linux/amd64"} + # for arm64 + - {instruct: "NATIVE", platform: "linux/arm64"} + + steps: + - name: Move Docker data directory + run: | + sudo systemctl stop docker + sudo mkdir -p /mnt/docker + sudo rsync -avz /var/lib/docker/ /mnt/docker + sudo rm -rf /var/lib/docker + sudo ln -s /mnt/docker /var/lib/docker + sudo systemctl start docker + + - + name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - + name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - + name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - + name: Build and push for amd64 + if: matrix.platform == 'linux/amd64' + uses: docker/build-push-action@v6 + with: + push: true + platforms: | + linux/amd64 + tags: | + ${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }} + ${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }} + build-args: | + CPU_INSTRUCT=${{ matrix.instruct }} + - + name: Build and push for arm64 + if: matrix.platform == 'linux/arm64' + uses: docker/build-push-action@v6 + with: + push: true + platforms: | + linux/arm64 + tags: | + ${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }} + ${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }} + build-args: | + CPU_INSTRUCT=${{ matrix.instruct }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index f8da261..58250d1 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt mmlu_result_q4km.json mmlu_result_q4km.log ktransformers/tests/mmlu_result_silicon.log +ktransformers/ktransformers_ext/cuda_musa/ diff --git a/Dockerfile b/Dockerfile index 6d4b214..1807150 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,7 @@ EOF FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server +ARG CPU_INSTRUCT=NATIVE WORKDIR /workspace ENV CUDA_HOME /usr/local/cuda COPY --from=web_compile /home/ktransformers /workspace/ktransformers @@ -28,8 +29,9 @@ git submodule init && git submodule update && pip install ninja pyproject numpy cpufeature && pip install flash-attn && -CPU_INSTRUCT=NATIVE KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose && -pip cache purge +CPU_INSTRUCT=${CPU_INSTRUCT} KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose && +pip cache purge && +cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/ EOF ENTRYPOINT ["tail", "-f", "/dev/null"] \ No newline at end of file diff --git a/README.md b/README.md index 61bacba..30f1425 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ Getting started with KTransformers is simple! Follow the steps below to set up a ### 📥 Installation -To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/). +To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).

📃 Brief Injection Tutorial

diff --git a/doc/en/install.md b/doc/en/install.md index b51c54f..269e8fb 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -81,7 +81,7 @@ Some preparation: git submodule update ``` - - [Optional] If you want to run with website, please [compile the website](./doc/en/api/server/website.md) before execute ```bash install.sh``` + - [Optional] If you want to run with website, please [compile the website](./api/server/website.md) before execute ```bash install.sh``` - For Linux - For simple install: @@ -103,7 +103,7 @@ Some preparation: install.bat ``` -* If you are developer, you can make use of the makefile to compile and format the code.
the detailed usage of makefile is [here](./doc/en/makefile_usage.md) +* If you are developer, you can make use of the makefile to compile and format the code.
the detailed usage of makefile is [here](./makefile_usage.md)

Local Chat

We provide a simple command-line local chat Python script that you can run for testing. diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index d9ecd7a..ecce9b7 100644 --- a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -30,6 +30,8 @@ if (NOT MSVC) option(LLAMA_F16C "llama: enable F16C" OFF) endif() option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF) +option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF) +option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF) # Architecture specific # TODO: probably these flags need to be tweaked on some architectures @@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) if (WIN32) include_directories("$ENV{CUDA_PATH}/include") elseif (UNIX) - find_package(CUDA REQUIRED) - include_directories("${CUDA_INCLUDE_DIRS}") + if (KTRANSFORMERS_USE_CUDA) + find_package(CUDA REQUIRED) + include_directories("${CUDA_INCLUDE_DIRS}") + add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) + endif() + + if (KTRANSFORMERS_USE_MUSA) + if (NOT EXISTS $ENV{MUSA_PATH}) + if (NOT EXISTS /opt/musa) + set(MUSA_PATH /usr/local/musa) + else() + set(MUSA_PATH /opt/musa) + endif() + else() + set(MUSA_PATH $ENV{MUSA_PATH}) + endif() + + list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake") + + find_package(MUSAToolkit) + if (MUSAToolkit_FOUND) + message(STATUS "MUSA Toolkit found") + add_compile_definitions(KTRANSFORMERS_USE_MUSA=1) + endif() + endif() endif() aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) @@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama) if(WIN32) target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart elseif(UNIX) - if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "") - set(ENV{CUDA_HOME} "/usr/local/cuda") + if(KTRANSFORMERS_USE_CUDA) + if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "") + set(ENV{CUDA_HOME} "/usr/local/cuda") + endif() + target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") + endif() + if(KTRANSFORMERS_USE_MUSA) + target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart) endif() - target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") endif() # Define the USE_NUMA option diff --git a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h index 9618e6b..d0f0c60 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h +++ b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h @@ -17,7 +17,11 @@ #include #include #include -#include "cuda_runtime.h" +#ifdef KTRANSFORMERS_USE_CUDA +#include "vendors/cuda.h" +#elif KTRANSFORMERS_USE_MUSA +#include "vendors/musa.h" +#endif #include "backend.h" #include "task_queue.h" diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/README.md b/ktransformers/ktransformers_ext/cpu_backend/vendors/README.md new file mode 100644 index 0000000..d179f66 --- /dev/null +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/README.md @@ -0,0 +1,3 @@ +## TODO + +This directory can be removed after updating the version of `llama.cpp`. \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h new file mode 100644 index 0000000..082ad2c --- /dev/null +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h @@ -0,0 +1,3 @@ +#pragma once + +#include \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h new file mode 100644 index 0000000..7c94102 --- /dev/null +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaStream_t musaStream_t +#define cudaHostFn_t musaHostFn_t \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp index 65c8bc4..0b1994d 100644 --- a/ktransformers/ktransformers_ext/cuda/binding.cpp +++ b/ktransformers/ktransformers_ext/cuda/binding.cpp @@ -1,15 +1,15 @@ /** * @Description : - * @Author : Azure-Tang + * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 - * @Version : 1.0.0 - * @LastEditors : kkk1nak0 - * @LastEditTime : 2024-08-12 03:05:04 + * @Version : 0.2.2 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "custom_gguf/ops.h" +#ifdef KTRANSFORMERS_USE_CUDA #include "gptq_marlin/ops.h" +#endif // Python bindings #include #include @@ -19,22 +19,46 @@ // namespace py = pybind11; PYBIND11_MODULE(KTransformersOps, m) { - m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", - py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), - py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), - py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); + + m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q8_0 data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q6_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q5_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q4_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q3_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q2_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize iq4_xs data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + +#ifdef KTRANSFORMERS_USE_CUDA + m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", + py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), + py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), + py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); +#endif } diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp deleted file mode 100644 index 99069d8..0000000 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "ops.h" -// Python bindings -#include -#include -#include -#include -#include -// namespace py = pybind11; - -int test(){ - return 5; -} - -torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device); - -PYBIND11_MODULE(cudaops, m) { - m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", - py::arg("data"), py::arg("blk_size"), py::arg("device")); - m.def("test", &test, "Function to test."); - -} diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu index d5b4a2c..e80efc4 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -2,26 +2,55 @@ * @Description : * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 - * @Version : 1.0.0 - * @LastEditors : kkk1nak0 - * @LastEditTime : 2024-08-12 04:18:04 + * @Version : 0.2.2 * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c * Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2024 by KVCache.AI, All Rights Reserved. */ #include +#include +#include #include #include #include #include #include -__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) { +__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; - for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); - const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 80))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); @@ -70,17 +99,85 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size } } -__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); + + int is = 0; + float dl, ml; + + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++)); + uint8_t sc = *scales; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml); + + scales = (uint8_t*)(data + block_id * blk_size + (is++)); + sc = *scales; + + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml); + + shift += 2; + } + q += 32; + } + } +} + +__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); + + int is = 0; + float dl, ml; + + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++)); + uint8_t sc = *scales; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml); + + scales = (uint8_t*)(data + block_id * blk_size + (is++)); + sc = *scales; + + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml); + + shift += 2; + } + q += 32; + } + } +} + +__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t kmask1 = 0x03030303; const uint32_t kmask2 = 0x0f0f0f0f; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); + const float d_all = __half2float(*(reinterpret_cast(data + block_id * blk_size + 108))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); @@ -126,19 +223,131 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size } } +__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); + const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); + uint8_t m = 1; + + + uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96); + + for (int i = 0; i < 3; i++) { + aux[i] = 0; + for (int j = 0; j < 4; j++) { + aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8); + } + } + + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4))); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4))); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + } +} + +__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); + + const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); + const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); + uint8_t m = 1; + + + uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96); + + for (int i = 0; i < 3; i++) { + aux[i] = 0; + for (int j = 0; j < 4; j++) { + aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8); + } + } + + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < 256; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4))); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4))); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + } +} + + +__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); - const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); + const float d = __half2float(*(reinterpret_cast(data + block_id * 144 + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; - for (int j = 0; j < blk_size; j += 64) { + for (int j = 0; j < ele_per_blk; j += 64) { uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; @@ -151,13 +360,61 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size } } -__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); + int is = 0; + uint8_t sc, m; + for (int j = 0; j < ele_per_blk; j += 64) { + uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l] >> 4) - m2); + q += 32; is += 2; + } + } +} + +__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); + int is = 0; + uint8_t sc, m; + for (int j = 0; j < ele_per_blk; j += 64) { + uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l] >> 4) - m2); + q += 32; is += 2; + } + } +} + +__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ - float* __restrict__ output_blk = (float*)(output + block_id * 256); + float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk); - const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); - const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); @@ -180,46 +437,165 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size } } -__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ + __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk); + + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); + + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); + + for (int j = 0; j < 256; j += 64) { + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2); + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } +} + +__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ + nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk); + + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); + const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); + + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); + + for (int j = 0; j < 256; j += 64) { + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1); + for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2); + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } +} + +__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); + float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); - //if (blk_size == 256){ - for (int n = 0; n < blk_size; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - output_blk[l + 0] = d * sc[is + 0] * q1; - output_blk[l + 32] = d * sc[is + 2] * q2; - output_blk[l + 64] = d * sc[is + 4] * q3; - output_blk[l + 96] = d * sc[is + 6] * q4; - } - output_blk += 128; - ql += 64; - qh += 32; - sc += 8; + for (int n = 0; n < ele_per_blk; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = d * sc[is + 0] * q1; + output_blk[l + 32] = d * sc[is + 2] * q2; + output_blk[l + 64] = d * sc[is + 4] * q3; + output_blk[l + 96] = d * sc[is + 6] * q4; } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + +__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); + + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); + const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); + + + for (int n = 0; n < ele_per_blk; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = __float2half(d * sc[is + 0] * q1); + output_blk[l + 32] = __float2half(d * sc[is + 2] * q2); + output_blk[l + 64] = __float2half(d * sc[is + 4] * q3); + output_blk[l + 96] = __float2half(d * sc[is + 6] * q4); + } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + +__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); + + const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); + const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); + const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); + + + for (int n = 0; n < ele_per_blk; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1); + output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2); + output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3); + output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4); + } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } } } static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { +__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); - const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); + float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk); + const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size))); + const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); @@ -236,152 +612,267 @@ __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_si } } -torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size))); + const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); + const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); + const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); + + for (int ib = 0; ib < 8; ++ib) { + const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]); + output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]); + } + output_blk += 32; + qs += 16; + } + } +} + +__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { + long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; + for (long long block_id=global_idx; block_id(data + block_id * blk_size))); + const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); + const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); + const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); + + for (int ib = 0; ib < 8; ++ib) { + const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]); + output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]); + } + output_blk += 32; + qs += 16; + } + } +} + +torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); - // create gpu - auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto scales_gpu = torch::empty({{num_blocks, 1}}, options_scales); - auto qs_gpu = torch::empty({num_blocks, 32}, options_qs); - // read on cpu - options_scales = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCPU); - options_qs = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU); + auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); + auto data_gpu = torch::empty({ num_bytes }, options); - // // reinterpret - auto scales = torch::from_blob(data.data_ptr(), {num_blocks, 1 + 16}, options_scales).slice(1, 0, 1); - auto qs = torch::from_blob(data.data_ptr(), {num_blocks, 2 + 32}, options_qs).slice(1, 2); - - auto scales_f32 = scales.to(torch::kFloat32); - scales_gpu.copy_(scales_f32, false); - qs_gpu.copy_(qs, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros_like(qs, torch::dtype(torch::kFloat32).device(device)); + auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device)); - // Launch kernel - dequantize_q8_0_kernel<<< 512, 256 >>>( - output.data_ptr(), scales_gpu.data_ptr(), qs_gpu.data_ptr(), num_blocks, 32); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kBFloat16: + dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kFloat32: + dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device) { +torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { // data.numel%blk_size should be 0, else raise err - int num_blocks = data.numel() / blk_size; + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); - // dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), 256, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kBFloat16: + dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kFloat32: + dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kBFloat16: + dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kFloat32: + dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) { +torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { // data.numel%blk_size should be 0, else raise err - int num_blocks = data.numel() / blk_size; + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q4_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), 256, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kBFloat16: + dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kFloat32: + dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kBFloat16: + dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kFloat32: + dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kBFloat16: + dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kFloat32: + dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } -torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) { - int num_blocks = data.numel() / blk_size; +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { + int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); - auto data_gpu = torch::empty({data.numel()}, options); + auto data_gpu = torch::empty({num_bytes}, options); - data_gpu.copy_(data, false); + cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); + //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device)); - - // Launch kernel - dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + switch (target_dtype) { + case torch::kFloat16: + dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kBFloat16: + dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + case torch::kFloat32: + dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); + break; + default: + printf("target type not support\n"); + exit(0); + } cudaDeviceSynchronize(); return output; } diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h index 666d455..a52db2d 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h @@ -13,10 +13,10 @@ #include #include -torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device); +torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); diff --git a/ktransformers/ktransformers_ext/cuda/test_dequant.py b/ktransformers/ktransformers_ext/cuda/test_dequant.py new file mode 100644 index 0000000..abca745 --- /dev/null +++ b/ktransformers/ktransformers_ext/cuda/test_dequant.py @@ -0,0 +1,16 @@ +import os +import sys +sys.path.insert(0,"/home/zbx/ktransformers") +from ktransformers.util.custom_gguf import GGUFLoader +import torch + +gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf") +gguf_loader_2 = GGUFLoader("/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/") + +torch.set_default_dtype(torch.bfloat16) + +tensor_1 = gguf_loader_1.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda") +tensor_2 = gguf_loader_2.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda") + +print(tensor_1[0, -64:]) +print(tensor_2[0, -64:]) \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py index accbc00..fadfb11 100644 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py +++ b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py @@ -90,7 +90,7 @@ def marlin_quantize( assert group_size <= size_k # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are @@ -107,7 +107,7 @@ def marlin_quantize( marlin_scale_perm_single[num_bits]) # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] for i in range(len(res_list)): res_list[i] = res_list[i].to(w.device) diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py index b3a0ba5..de73667 100644 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py +++ b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py @@ -11,8 +11,7 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): - assert q_w.shape == w_ref.shape +def permute_rows(q_w: torch.Tensor, group_size: int): orig_device = q_w.device k_size, _ = q_w.shape @@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() return ( - w_ref.to(device=orig_device), q_w.to(device=orig_device), g_idx.to(device=orig_device), rand_perm.to(device=orig_device), @@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, q_w += half_q_val q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s - # Restore original shapes if group_size < size_k: @@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, return w q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) s = s.reshape((-1, size_n)).contiguous() @@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) - w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + q_w, g_idx, rand_perm = permute_rows(q_w, group_size) return ( - w_ref.to(device=orig_device), q_w.to(device=orig_device), s.to(device=orig_device), g_idx.to(device=orig_device), diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index fb59a17..d5e74de 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -168,10 +168,7 @@ def local_chat( if mode == 'long_context': assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ "please change max_seq_len in ~/.ktransformers/config.yaml" - torch.set_default_dtype( - torch.bfloat16 - ) # TODO: Remove this, replace dtype using config - + if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled: generated = prefill_and_generate( model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py index 692020d..e14a521 100644 --- a/ktransformers/models/modeling_deepseek.py +++ b/ktransformers/models/modeling_deepseek.py @@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits[:,-1,:].unsqueeze(0).float() + logits = self.lm_head(hidden_states[:,-1:,:]).float() loss = None if labels is not None: diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py index 277258a..952eed7 100644 --- a/ktransformers/models/modeling_deepseek_v3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states.to(self.lm_head.weight.device)) + logits = self.lm_head(hidden_states[:,-1:,:]) logits = logits.float() loss = None diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index e778102..96d3578 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -21,6 +21,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl MarlinWorkspace, marlin_quantize, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MIN_THREAD_K, GPTQ_MARLIN_MAX_PARALLEL, ) from ktransformers.operators.base_operator import BaseInjectedModule @@ -65,6 +66,8 @@ class KLinearBase(ABC): self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] + self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill. + @abstractmethod def forward(self, x: torch.Tensor) -> torch.Tensor: pass @@ -141,6 +144,7 @@ class KLinearTorch(KLinearBase): return x def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if self.loaded: return if device is None: device = self.device if w is None: w = self.load_weight(device=device) # else: self.out_features = w.shape[0], self.in_features = w.shape[1] @@ -164,6 +168,7 @@ class KLinearTorch(KLinearBase): self.weight = self.weight.to(device) if self.has_bias: self.bias = self.bias.to(device) + self.loaded = True def unload(self): if self.weight is not None: @@ -251,20 +256,36 @@ class KLinearMarlin(KLinearBase): self.group_size = group_size self.act_order = act_order self.is_k_full = is_k_full + self.padding = False + self.orin_in_features = self.in_features + self.orin_out_features = self.out_features + if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0: + #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding") + self.padding = True + self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K + self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N + #print(f"After padding: in_features={in_features}, out_features={out_features}") + + self.k = self.in_features + self.n = self.out_features def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if self.loaded: return if device is None: device = self.device assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" + + #if self.in_features * self.out_features: if w is None: w = self.load_weight(device=device) if isinstance(w, nn.Parameter): # pad weight - weight = w.view(self.out_features, self.in_features).T + weight = w.view(self.orin_out_features, self.orin_in_features).T self.has_bias = False elif isinstance(w, tuple): w = list(w) - weight = w[0].view(self.out_features, self.in_features).T + weight = w[0].view(self.orin_out_features, self.orin_in_features).T + self.bias = w[1].view(self.orin_out_features) self.bias = w[1] self.has_bias = True else: @@ -272,8 +293,14 @@ class KLinearMarlin(KLinearBase): weight = weight.to(device) if self.has_bias: self.bias = self.bias.to(device) + + if self.padding: + padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device) + padded_weight[:self.orin_in_features, :self.orin_out_features] = weight + weight = padded_weight + # Pack Marlin linear - w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( weight, self.num_bits, self.group_size, self.act_order ) self.workspace = MarlinWorkspace( @@ -286,6 +313,7 @@ class KLinearMarlin(KLinearBase): self.sort_indices = sort_indices self.k = weight.shape[0] self.n = weight.shape[1] + self.loaded = True def forward(self, x: torch.Tensor) -> torch.Tensor: # Only support input x as BF16 and FP16 @@ -293,6 +321,11 @@ class KLinearMarlin(KLinearBase): orig_shape = list(x.shape) orig_dtype = x.dtype x = x.reshape(-1, orig_shape[-1]) + x = x.reshape(-1, x.shape[-1]) + if self.padding: + padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype) + padding_input[:,:self.orin_in_features] = x + x = padding_input marlin_s = self.marlin_s.to(x.dtype) x = KTransformersOps.gptq_marlin_gemm( x, @@ -307,6 +340,11 @@ class KLinearMarlin(KLinearBase): x.shape[-1], self.is_k_full, ) + if self.padding: + x = x[:,:self.orin_out_features] + orig_shape[-1] = self.orin_out_features + else: + orig_shape[-1] = self.out_features if self.has_bias: x = x + self.bias orig_shape[-1] = self.n @@ -450,24 +488,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): # build all the linear operators if prefill_op is not None: assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" - if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0): - print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.") - print(f"module info: key:{key} orig_module:{orig_module}") - self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs) - else: - self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs) else: self.prefill_linear = None if generate_op is not None: assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported" - if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0): - print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.") - print(f"module info: key:{key} orig_module:{orig_module}") - self.generate_op = "KLinearTorch" - self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs) - else: - self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) + self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) else: self.generate_linear = None self.mode = InferenceState.UNLOAD diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py index 32eab01..331e6cf 100644 --- a/ktransformers/optimize/optimize.py +++ b/ktransformers/optimize/optimize.py @@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo gguf_loader=GGUFLoader(gguf_path) with torch.device("meta"): inject(module, optimize_config, model_config, gguf_loader) + # pre load lm_head because its big inter result + load_weights(module.lm_head, gguf_loader, "lm_head.") load_weights(module, gguf_loader) module.gguf_loader = gguf_loader del_meta(module) diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml index a87a30c..66a420a 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml @@ -219,8 +219,20 @@ kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" + - match: - name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml index 269257e..f409376 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml @@ -118,7 +118,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml index b115aba..7f3e44e 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml @@ -15,6 +15,18 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" + +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml index 99d01c0..158892d 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml @@ -118,7 +118,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml index b115aba..7f3e44e 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml @@ -15,6 +15,18 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" + +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml index 84ab801..03c85a0 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml @@ -188,7 +188,7 @@ # !!!Do remember 'close' cuda graph if you are using marlin expert.!!! # !!!KExpertsTorch is untested, we don't have enough VRAM.!!! -# # GPU 0: layers 3–4 +# GPU 0: layers 3–4 # - match: # name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$" # replace: @@ -363,11 +363,20 @@ generate_device: "cuda:2" prefill_device: "cuda:2" -# don't inject lm_head if already inject marlin experts - -# For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config) - match: - name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +# For final modules (model.norm), ensure they are on GPU 3 (as in your original config) +- match: + name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml index a10b57f..b00d2b4 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml @@ -713,11 +713,20 @@ generate_device: "cuda:7" prefill_device: "cuda:7" -# don't inject lm_head if already inject marlin experts - -# For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config) - match: - name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:7" + prefill_device: "cuda:7" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +# For final modules (model.norm), ensure they are on GPU 7 (as in your original config) +- match: + name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml index 92571b5..6b39121 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml @@ -153,7 +153,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml index 06ab4db..50e282d 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -135,7 +135,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml index 7a44c5d..6fb6586 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml @@ -5,6 +5,18 @@ kwargs: generate_device: "cuda" prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously diff --git a/ktransformers/optimize/optimize_rules/Mixtral.yaml b/ktransformers/optimize/optimize_rules/Mixtral.yaml index 7d48812..80a346a 100644 --- a/ktransformers/optimize/optimize_rules/Mixtral.yaml +++ b/ktransformers/optimize/optimize_rules/Mixtral.yaml @@ -15,6 +15,16 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.block_sparse_moe$" class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml index da4fb4a..da01c82 100644 --- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml @@ -77,9 +77,19 @@ kwargs: generate_device: "cpu" prefill_device: "cpu" +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" - match: - name: "(^model.norm)|(^lm_head)" + name: "(^model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml index 0cc2edf..38e9e73 100644 --- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml @@ -15,6 +15,16 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index edca541..49a3f16 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext): class KTransformersInterface(TransformersInterface): def __init__(self, args: ConfigArgs = default_args): self.args = args - torch.set_default_dtype(torch.bfloat16) torch.set_grad_enabled(False) self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code) config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code) + torch.set_default_dtype(config.torch_dtype) if config.architectures[0] == "Qwen2MoeForCausalLM": config._attn_implementation = "flash_attention_2" diff --git a/ktransformers/tests/mmlu_pro_test.py b/ktransformers/tests/mmlu_pro_test.py index d44be2a..27eb9b2 100644 --- a/ktransformers/tests/mmlu_pro_test.py +++ b/ktransformers/tests/mmlu_pro_test.py @@ -176,7 +176,7 @@ if __name__ == "__main__": parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file") parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file") parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path") - parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") + parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL") # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL") args = parser.parse_args() diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index d054ad3..d26dc26 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -26,6 +26,7 @@ from enum import IntEnum import torch import KTransformersOps from .custom_loader import SafeTensorLoader +import ctypes class GGMLQuantizationType(IntEnum): F32 = 0 @@ -305,7 +306,7 @@ class GGUFLoader: data = torch.from_numpy(data) return data, ggml_type - def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor: + def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor: t = self.tensor_info[name] if device.lower() == "cpu": print(f"loading expert {expert_id} of {name} with CPU") @@ -324,19 +325,21 @@ class GGUFLoader: data = data[offset: offset + block_size * blocks_per_experts] if "cuda" in device.lower(): - values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) + values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype) else: values = GGML_DEQUANTIZE[ggml_name](data) - values = torch.from_numpy(values) + values = torch.from_numpy(values.copy()) values = values.view(shape[-2::-1]) return values - def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor: + def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor: t = self.tensor_info[name] if device.lower() == "cpu": print(f"loading {name} with CPU") + if target_dtype == None: + target_dtype = torch.get_default_dtype() shape = t["shape"] ggml_type = t["ggml_type"] @@ -348,16 +351,38 @@ class GGUFLoader: data = self.get_mmap_tensor(name) - if "cuda" in device.lower(): - values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) - #values = GGML_DEQUANTIZE[ggml_name](data) - #print("load_gguf_tensor") - #values = torch.from_numpy(values).to(device = device) + block_size = GGML_BLOCK_SIZES[ggml_name] + elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name] + num_elements = int(np.prod(shape)) + num_blocks = num_elements // elements_per_block + + blocks_per_iter = 16384 + if num_blocks > blocks_per_iter: # dequant large tensor + values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device) + for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): + blocks_begin = i * blocks_per_iter + blocks_end = min(blocks_begin + blocks_per_iter, num_blocks) + if "cuda" in device.lower(): + cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) + else: + cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) + cur_values = torch.from_numpy(cur_values.copy()) + + cur_values = cur_values.view(-1, elements_per_block) + if ggml_name == "BF16": + cur_values = cur_values.view(torch.bfloat16) + values[blocks_begin : blocks_end] = cur_values else: - values = GGML_DEQUANTIZE[ggml_name](data) - values = torch.from_numpy(values) + if "cuda" in device.lower(): + values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) + else: + values = GGML_DEQUANTIZE[ggml_name](data) + values = torch.from_numpy(values) + if ggml_name == "BF16": values = values.view(torch.bfloat16) + + values = values.view(shape[::-1]) if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: n_head = self.gguf_file_meta['llama.attention.head_count'] @@ -456,14 +481,15 @@ def dequantize_q2_k(data): return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) -def dequantize_q2_k_gpu(data, device:str ="cuda"): +def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q2_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q2_k(data, block_size, device) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q3_k(data): # C implementation @@ -507,14 +533,15 @@ def dequantize_q3_k(data): (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) ], axis=1) -def dequantize_q3_k_gpu(data, device:str ="cuda"): +def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q3_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q3_k(data, block_size, device) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q4_k(data): # C implementation @@ -538,13 +565,15 @@ def dequantize_q4_k(data): # Dequantize final weights using scales and offsets return factors * qs2 - offsets -def dequantize_q4_k_gpu(data, device:str ="cuda"): +def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): + block_size = GGML_BLOCK_SIZES["Q4_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q4_k(data, 144, device) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q5_k(data): # C implementation @@ -602,14 +631,15 @@ def dequantize_q5_k(data): d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, ], axis=1) -def dequantize_q5_k_gpu(data, device:str ="cuda"): +def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q5_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q5_k(data, block_size, device) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q6_k(data): # C implementation @@ -660,13 +690,14 @@ def dequantize_q6_k(data): ], axis=1) # @torch.jit.script -def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"): +def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q6_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q6_k(data, block_size, device) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8) @@ -700,13 +731,14 @@ def dequantize_iq4_xs(data): return y.flatten() -def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"): +def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["IQ4_XS"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) - data = torch.from_numpy(data) - return KTransformersOps.dequantize_iq4_xs(data, block_size, device) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q4_0(data): # C implementation @@ -723,7 +755,7 @@ def dequantize_q4_0(data): scales * ((qs >> 4).astype(np.int8) - 8), ], axis=1) -def dequantize_q4_0_gpu(data): +def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q5_0(data): @@ -747,7 +779,7 @@ def dequantize_q5_0(data): scales * x1, ], axis=1) -def dequantize_q5_0_gpu(data): +def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q8_0(data): @@ -759,32 +791,41 @@ def dequantize_q8_0(data): qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] return scales * qs -def dequantize_q8_0_gpu(data, device:str = "cuda"): +def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 - num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"] + + block_size = GGML_BLOCK_SIZES["Q8_0"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"] device = torch.device(device) data = np.frombuffer(data, dtype=data.dtype) - data = torch.from_numpy(data) - return KTransformersOps.dequantize_q8_0(data, 34, device) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_f32(data): return np.frombuffer(data, dtype=np.float32) -def dequantize_f32_gpu(data, device): +def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float32) - res = torch.from_numpy(data) - res_gpu = torch.empty_like(res, device=device) + res = torch.from_numpy(data.copy()) + res_gpu = torch.empty_like(res, device=device, dtype=target_dtype) res_gpu.copy_(res) return res_gpu def dequantize_f16(data): return np.frombuffer(data, dtype=np.float16) -def dequantize_f16_gpu(data, device): +def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float16) - res = torch.from_numpy(data) + res = torch.from_numpy(data.copy()) + res_gpu = torch.empty_like(res, device=device, dtype=target_dtype) + res_gpu.copy_(res) + return res_gpu + +def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()): + data = np.frombuffer(data, dtype=np.float16) + res = torch.from_numpy(data.copy()) res_gpu = torch.empty_like(res, device=device) res_gpu.copy_(res) return res_gpu @@ -807,7 +848,7 @@ GGML_DEQUANTIZE = { GGML_DEQUANTIZE_GPU = { "F32": dequantize_f32_gpu, "F16": dequantize_f16_gpu, - "BF16": dequantize_f16_gpu, + "BF16": dequantize_bf16_gpu, "Q4_0": dequantize_q4_0_gpu, "Q5_0": dequantize_q5_0_gpu, "Q8_0": dequantize_q8_0_gpu, diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 1c21135..824bd41 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -90,7 +90,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str raise Exception(f"can't find {translated_key} in GGUF file!") def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''): - # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}") + #print(f"recursively loading weights {prefix}") if not isinstance(module, base_operator.BaseInjectedModule): load_cur_state_dict(module, gguf_loader, prefix) for name, child in module._modules.items(): diff --git a/setup.py b/setup.py index ddd4835..345fdb1 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,16 @@ #!/usr/bin/env python # coding=utf-8 ''' -Description : +Description : Author : chenxl Date : 2024-07-27 16:15:27 Version : 1.0.0 -LastEditors : chenxl +LastEditors : chenxl LastEditTime : 2024-08-14 16:36:19 Adapted from: https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py Copyright (c) 2023, Tri Dao. -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os @@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel from setuptools import setup, Extension from cpufeature.extension import CPUFeature from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +try: + from torch_musa.utils.simple_porting import SimplePorting + from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME +except ImportError: + MUSA_HOME=None class CpuInstructInfo: CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE") @@ -40,7 +45,7 @@ class CpuInstructInfo: CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON" CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON" CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON" - + class VersionInfo: THIS_DIR = os.path.dirname(os.path.abspath(__file__)) PACKAGE_NAME = "ktransformers" @@ -49,6 +54,16 @@ class VersionInfo: ) FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE" + def get_musa_bare_metal_version(self, musa_dir): + raw_output = subprocess.run( + [musa_dir + "/bin/mcc", "-v"], check=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8") + output = raw_output.split() + release_idx = output.index("version") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}" + return musa_version + def get_cuda_bare_metal_version(self, cuda_dir): raw_output = subprocess.check_output( [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) @@ -58,7 +73,7 @@ class VersionInfo: cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}" return cuda_version - def get_cuda_version_of_torch(self,): + def get_cuda_version_of_torch(self): torch_cuda_version = parse(torch.version.cuda) cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" return cuda_version @@ -117,7 +132,7 @@ class VersionInfo: torch_version_raw = parse(torch.__version__) torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}" return torch_version - + def get_flash_version(self,): version_file = os.path.join( Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py") @@ -128,12 +143,21 @@ class VersionInfo: return flash_version def get_package_version(self, full_version=False): - flash_version = self.get_flash_version() - package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}" + flash_version = str(self.get_flash_version()) + torch_version = self.get_torch_version() + cpu_instruct = self.get_cpu_instruct() + backend_version = "" + if CUDA_HOME is not None: + backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}" + elif MUSA_HOME is not None: + backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}" + else: + raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") + package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}" if full_version: return package_version if not VersionInfo.FORCE_BUILD: - return str(flash_version) + return flash_version return package_version @@ -218,11 +242,19 @@ class CMakeBuild(BuildExtension): f"-DPYTHON_EXECUTABLE={sys.executable}", f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm ] + + if CUDA_HOME is not None: + cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"] + elif MUSA_HOME is not None: + cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"] + else: + raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") + build_args = [] if "CMAKE_ARGS" in os.environ: cmake_args += [ item for item in os.environ["CMAKE_ARGS"].split(" ") if item] - + if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY: cpu_args = CpuInstructInfo.CMAKE_FANCY elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512: @@ -231,7 +263,7 @@ class CMakeBuild(BuildExtension): cpu_args = CpuInstructInfo.CMAKE_AVX2 else: cpu_args = CpuInstructInfo.CMAKE_NATIVE - + cmake_args += [ item for item in cpu_args.split(" ") if item ] @@ -276,8 +308,13 @@ class CMakeBuild(BuildExtension): "-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + cpu_count = os.cpu_count() + if cpu_count is None: + cpu_count = 1 if hasattr(self, "parallel") and self.parallel: - build_args += [f"-j{self.parallel}"] + build_args += [f"--parallel={self.parallel}"] + else: + build_args += [f"--parallel={cpu_count}"] print("CMake args:", cmake_args) build_temp = Path(ext.sourcedir) / "build" if not build_temp.exists(): @@ -288,28 +325,55 @@ class CMakeBuild(BuildExtension): print("Standard output:", result.stdout) print("Standard error:", result.stderr) subprocess.run( - ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True + ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True ) +if CUDA_HOME is not None: + ops_module = CUDAExtension('KTransformersOps', [ + 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu', + 'ktransformers/ktransformers_ext/cuda/binding.cpp', + 'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu' + ], + extra_compile_args={ + 'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'], + 'nvcc': [ + '-O3', + '--use_fast_math', + '-Xcompiler', '-fPIC', + '-DKTRANSFORMERS_USE_CUDA', + ] + } + ) +elif MUSA_HOME is not None: + SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={ + # Common rules + "at::cuda": "at::musa", + "#include ": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"", + "#include ": "#include \"torch_musa/csrc/core/MUSAGuard.h\"", + }).run() + ops_module = MUSAExtension('KTransformersOps', [ + 'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu', + 'ktransformers/ktransformers_ext/cuda_musa/binding.cpp', + # TODO: Add Marlin support for MUSA. + # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu' + ], + extra_compile_args={ + 'cxx': ['force_mcc'], + 'mcc': [ + '-O3', + '-DKTRANSFORMERS_USE_MUSA', + '-DTHRUST_IGNORE_CUB_VERSION_CHECK', + ] + } + ) +else: + raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") setup( version=VersionInfo().get_package_version(), cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, ext_modules=[ CMakeExtension("cpuinfer_ext"), - CUDAExtension('KTransformersOps', [ - 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu', - 'ktransformers/ktransformers_ext/cuda/binding.cpp', - 'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu' - ], - extra_compile_args={ - 'cxx': ['-O3'], - 'nvcc': [ - '-O3', - '--use_fast_math', - '-Xcompiler', '-fPIC', - ] - } - ) + ops_module, ] ) diff --git a/test_prompt.txt b/test_prompt.txt index 69fd23b..c749c0e 100644 --- a/test_prompt.txt +++ b/test_prompt.txt @@ -6,4 +6,15 @@ Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. 阅读以上文字,并概括大意