Merge pull request #99 from chenht2022/main

Adapt Windows
This commit is contained in:
Chen Hongtao 2024-10-09 19:09:58 +08:00 committed by GitHub
commit 43fc7f44a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 102 additions and 38 deletions

View file

@ -1,10 +1,10 @@
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-17 12:25:51
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:44
* @Description :
* @Author : chenht2022
* @Date : 2024-07-17 12:25:51
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-10-09 11:08:10
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "task_queue.h"
@ -17,8 +17,9 @@ TaskQueue::TaskQueue() {
TaskQueue::~TaskQueue() {
{
std::unique_lock<std::mutex> lock(mutex);
mutex.lock();
exit_flag.store(true, std::memory_order_seq_cst);
mutex.unlock();
}
cv.notify_all();
if (worker.joinable()) {
@ -28,9 +29,10 @@ TaskQueue::~TaskQueue() {
void TaskQueue::enqueue(std::function<void()> task) {
{
std::unique_lock<std::mutex> lock(mutex);
mutex.lock();
tasks.push(task);
sync_flag.store(false, std::memory_order_seq_cst);
mutex.unlock();
}
cv.notify_one();
}
@ -44,20 +46,22 @@ void TaskQueue::processTasks() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); });
mutex.lock();
cv.wait(mutex, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); });
if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) {
return;
}
task = tasks.front();
tasks.pop();
mutex.unlock();
}
task();
{
std::lock_guard<std::mutex> lock(mutex);
mutex.lock();
if (tasks.empty()) {
sync_flag.store(true, std::memory_order_seq_cst);
}
mutex.unlock();
}
}
}
}

View file

@ -1,10 +1,10 @@
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenxl
* @LastEditTime : 2024-08-08 04:23:51
* @Description :
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenht
* @LastEditTime : 2024-10-09 11:08:07
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_TASKQUEUE_H
@ -22,36 +22,96 @@
#endif
class custom_mutex {
private:
private:
#ifdef _WIN32
HANDLE global_mutex;
CRITICAL_SECTION cs;
#else
std::mutex global_mutex;
std::mutex mtx;
#endif
public:
custom_mutex()
{
public:
custom_mutex() {
#ifdef _WIN32
HANDLE global_mutex;
InitializeCriticalSection(&cs);
#else
// No initialization required for std::mutex
#endif
}
void lock()
{
~custom_mutex() {
#ifdef _WIN32
WaitForSingleObject(global_mutex, INFINITE);
#else
global_mutex.lock();
DeleteCriticalSection(&cs);
#endif
}
void unlock()
{
void lock() {
#ifdef _WIN32
ReleaseMutex(global_mutex);
EnterCriticalSection(&cs);
#else
global_mutex.unlock();
mtx.lock();
#endif
}
void unlock() {
#ifdef _WIN32
LeaveCriticalSection(&cs);
#else
mtx.unlock();
#endif
}
#ifdef _WIN32
CRITICAL_SECTION* get_handle() {
return &cs;
}
#else
std::mutex* get_handle() {
return &mtx;
}
#endif
};
class custom_condition_variable {
private:
#ifdef _WIN32
CONDITION_VARIABLE cond_var;
#else
std::condition_variable cond_var;
#endif
public:
custom_condition_variable() {
#ifdef _WIN32
InitializeConditionVariable(&cond_var);
#endif
}
template <typename Predicate>
void wait(custom_mutex& mutex, Predicate pred) {
#ifdef _WIN32
while (!pred()) {
SleepConditionVariableCS(&cond_var, mutex.get_handle(), INFINITE);
}
#else
std::unique_lock<std::mutex> lock(*mutex.get_handle(), std::adopt_lock);
cond_var.wait(lock, pred);
lock.release();
#endif
}
void notify_one() {
#ifdef _WIN32
WakeConditionVariable(&cond_var);
#else
cond_var.notify_one();
#endif
}
void notify_all() {
#ifdef _WIN32
WakeAllConditionVariable(&cond_var);
#else
cond_var.notify_all();
#endif
}
};
@ -69,10 +129,10 @@ class TaskQueue {
void processTasks();
std::queue<std::function<void()>> tasks;
std::mutex mutex;
std::condition_variable cv;
custom_mutex mutex;
custom_condition_variable cv;
std::thread worker;
std::atomic<bool> sync_flag;
std::atomic<bool> exit_flag;
};
#endif
#endif