diff --git a/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp b/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp index 5d20d1e..297a4ea 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp +++ b/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp @@ -1,10 +1,10 @@ /** - * @Description : - * @Author : chenht2022 - * @Date : 2024-07-17 12:25:51 - * @Version : 1.0.0 - * @LastEditors : chenht2022 - * @LastEditTime : 2024-07-25 10:33:44 + * @Description : + * @Author : chenht2022 + * @Date : 2024-07-17 12:25:51 + * @Version : 1.0.0 + * @LastEditors : chenht2022 + * @LastEditTime : 2024-10-09 11:08:10 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "task_queue.h" @@ -17,8 +17,9 @@ TaskQueue::TaskQueue() { TaskQueue::~TaskQueue() { { - std::unique_lock lock(mutex); + mutex.lock(); exit_flag.store(true, std::memory_order_seq_cst); + mutex.unlock(); } cv.notify_all(); if (worker.joinable()) { @@ -28,9 +29,10 @@ TaskQueue::~TaskQueue() { void TaskQueue::enqueue(std::function task) { { - std::unique_lock lock(mutex); + mutex.lock(); tasks.push(task); sync_flag.store(false, std::memory_order_seq_cst); + mutex.unlock(); } cv.notify_one(); } @@ -44,20 +46,22 @@ void TaskQueue::processTasks() { while (true) { std::function task; { - std::unique_lock lock(mutex); - cv.wait(lock, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); }); + mutex.lock(); + cv.wait(mutex, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); }); if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) { return; } task = tasks.front(); tasks.pop(); + mutex.unlock(); } task(); { - std::lock_guard lock(mutex); + mutex.lock(); if (tasks.empty()) { sync_flag.store(true, std::memory_order_seq_cst); } + mutex.unlock(); } } -} +} \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h index 5325dcc..4f9f112 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h +++ b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h @@ -1,10 +1,10 @@ /** - * @Description : - * @Author : chenht2022 - * @Date : 2024-07-16 10:43:18 - * @Version : 1.0.0 - * @LastEditors : chenxl - * @LastEditTime : 2024-08-08 04:23:51 + * @Description : + * @Author : chenht2022 + * @Date : 2024-07-16 10:43:18 + * @Version : 1.0.0 + * @LastEditors : chenht + * @LastEditTime : 2024-10-09 11:08:07 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_TASKQUEUE_H @@ -22,36 +22,96 @@ #endif class custom_mutex { -private: + private: #ifdef _WIN32 - HANDLE global_mutex; + CRITICAL_SECTION cs; #else - std::mutex global_mutex; + std::mutex mtx; #endif - -public: - custom_mutex() - { + + public: + custom_mutex() { #ifdef _WIN32 - HANDLE global_mutex; + InitializeCriticalSection(&cs); +#else + // No initialization required for std::mutex #endif } - void lock() - { + ~custom_mutex() { #ifdef _WIN32 - WaitForSingleObject(global_mutex, INFINITE); -#else - global_mutex.lock(); + DeleteCriticalSection(&cs); #endif } - void unlock() - { + void lock() { #ifdef _WIN32 - ReleaseMutex(global_mutex); + EnterCriticalSection(&cs); #else - global_mutex.unlock(); + mtx.lock(); +#endif + } + + void unlock() { +#ifdef _WIN32 + LeaveCriticalSection(&cs); +#else + mtx.unlock(); +#endif + } + +#ifdef _WIN32 + CRITICAL_SECTION* get_handle() { + return &cs; + } +#else + std::mutex* get_handle() { + return &mtx; + } +#endif +}; + +class custom_condition_variable { + private: +#ifdef _WIN32 + CONDITION_VARIABLE cond_var; +#else + std::condition_variable cond_var; +#endif + + public: + custom_condition_variable() { +#ifdef _WIN32 + InitializeConditionVariable(&cond_var); +#endif + } + + template + void wait(custom_mutex& mutex, Predicate pred) { +#ifdef _WIN32 + while (!pred()) { + SleepConditionVariableCS(&cond_var, mutex.get_handle(), INFINITE); + } +#else + std::unique_lock lock(*mutex.get_handle(), std::adopt_lock); + cond_var.wait(lock, pred); + lock.release(); +#endif + } + + void notify_one() { +#ifdef _WIN32 + WakeConditionVariable(&cond_var); +#else + cond_var.notify_one(); +#endif + } + + void notify_all() { +#ifdef _WIN32 + WakeAllConditionVariable(&cond_var); +#else + cond_var.notify_all(); #endif } }; @@ -69,10 +129,10 @@ class TaskQueue { void processTasks(); std::queue> tasks; - std::mutex mutex; - std::condition_variable cv; + custom_mutex mutex; + custom_condition_variable cv; std::thread worker; std::atomic sync_flag; std::atomic exit_flag; }; -#endif +#endif \ No newline at end of file