Merge pull request #99 from chenht2022/main

Adapt Windows
This commit is contained in:
Chen Hongtao 2024-10-09 19:09:58 +08:00 committed by GitHub
commit 43fc7f44a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 102 additions and 38 deletions

View file

@ -4,7 +4,7 @@
* @Date : 2024-07-17 12:25:51 * @Date : 2024-07-17 12:25:51
* @Version : 1.0.0 * @Version : 1.0.0
* @LastEditors : chenht2022 * @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:44 * @LastEditTime : 2024-10-09 11:08:10
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/ **/
#include "task_queue.h" #include "task_queue.h"
@ -17,8 +17,9 @@ TaskQueue::TaskQueue() {
TaskQueue::~TaskQueue() { TaskQueue::~TaskQueue() {
{ {
std::unique_lock<std::mutex> lock(mutex); mutex.lock();
exit_flag.store(true, std::memory_order_seq_cst); exit_flag.store(true, std::memory_order_seq_cst);
mutex.unlock();
} }
cv.notify_all(); cv.notify_all();
if (worker.joinable()) { if (worker.joinable()) {
@ -28,9 +29,10 @@ TaskQueue::~TaskQueue() {
void TaskQueue::enqueue(std::function<void()> task) { void TaskQueue::enqueue(std::function<void()> task) {
{ {
std::unique_lock<std::mutex> lock(mutex); mutex.lock();
tasks.push(task); tasks.push(task);
sync_flag.store(false, std::memory_order_seq_cst); sync_flag.store(false, std::memory_order_seq_cst);
mutex.unlock();
} }
cv.notify_one(); cv.notify_one();
} }
@ -44,20 +46,22 @@ void TaskQueue::processTasks() {
while (true) { while (true) {
std::function<void()> task; std::function<void()> task;
{ {
std::unique_lock<std::mutex> lock(mutex); mutex.lock();
cv.wait(lock, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); }); cv.wait(mutex, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); });
if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) { if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) {
return; return;
} }
task = tasks.front(); task = tasks.front();
tasks.pop(); tasks.pop();
mutex.unlock();
} }
task(); task();
{ {
std::lock_guard<std::mutex> lock(mutex); mutex.lock();
if (tasks.empty()) { if (tasks.empty()) {
sync_flag.store(true, std::memory_order_seq_cst); sync_flag.store(true, std::memory_order_seq_cst);
} }
mutex.unlock();
} }
} }
} }

View file

@ -3,8 +3,8 @@
* @Author : chenht2022 * @Author : chenht2022
* @Date : 2024-07-16 10:43:18 * @Date : 2024-07-16 10:43:18
* @Version : 1.0.0 * @Version : 1.0.0
* @LastEditors : chenxl * @LastEditors : chenht
* @LastEditTime : 2024-08-08 04:23:51 * @LastEditTime : 2024-10-09 11:08:07
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/ **/
#ifndef CPUINFER_TASKQUEUE_H #ifndef CPUINFER_TASKQUEUE_H
@ -24,34 +24,94 @@
class custom_mutex { class custom_mutex {
private: private:
#ifdef _WIN32 #ifdef _WIN32
HANDLE global_mutex; CRITICAL_SECTION cs;
#else #else
std::mutex global_mutex; std::mutex mtx;
#endif #endif
public: public:
custom_mutex() custom_mutex() {
{
#ifdef _WIN32 #ifdef _WIN32
HANDLE global_mutex; InitializeCriticalSection(&cs);
#else
// No initialization required for std::mutex
#endif #endif
} }
void lock() ~custom_mutex() {
{
#ifdef _WIN32 #ifdef _WIN32
WaitForSingleObject(global_mutex, INFINITE); DeleteCriticalSection(&cs);
#else
global_mutex.lock();
#endif #endif
} }
void unlock() void lock() {
{
#ifdef _WIN32 #ifdef _WIN32
ReleaseMutex(global_mutex); EnterCriticalSection(&cs);
#else #else
global_mutex.unlock(); mtx.lock();
#endif
}
void unlock() {
#ifdef _WIN32
LeaveCriticalSection(&cs);
#else
mtx.unlock();
#endif
}
#ifdef _WIN32
CRITICAL_SECTION* get_handle() {
return &cs;
}
#else
std::mutex* get_handle() {
return &mtx;
}
#endif
};
class custom_condition_variable {
private:
#ifdef _WIN32
CONDITION_VARIABLE cond_var;
#else
std::condition_variable cond_var;
#endif
public:
custom_condition_variable() {
#ifdef _WIN32
InitializeConditionVariable(&cond_var);
#endif
}
template <typename Predicate>
void wait(custom_mutex& mutex, Predicate pred) {
#ifdef _WIN32
while (!pred()) {
SleepConditionVariableCS(&cond_var, mutex.get_handle(), INFINITE);
}
#else
std::unique_lock<std::mutex> lock(*mutex.get_handle(), std::adopt_lock);
cond_var.wait(lock, pred);
lock.release();
#endif
}
void notify_one() {
#ifdef _WIN32
WakeConditionVariable(&cond_var);
#else
cond_var.notify_one();
#endif
}
void notify_all() {
#ifdef _WIN32
WakeAllConditionVariable(&cond_var);
#else
cond_var.notify_all();
#endif #endif
} }
}; };
@ -69,8 +129,8 @@ class TaskQueue {
void processTasks(); void processTasks();
std::queue<std::function<void()>> tasks; std::queue<std::function<void()>> tasks;
std::mutex mutex; custom_mutex mutex;
std::condition_variable cv; custom_condition_variable cv;
std::thread worker; std::thread worker;
std::atomic<bool> sync_flag; std::atomic<bool> sync_flag;
std::atomic<bool> exit_flag; std::atomic<bool> exit_flag;