diff --git a/kt-kernel/cpu_backend/task_queue.cpp b/kt-kernel/cpu_backend/task_queue.cpp index aee5fc14..56509298 100644 --- a/kt-kernel/cpu_backend/task_queue.cpp +++ b/kt-kernel/cpu_backend/task_queue.cpp @@ -24,7 +24,11 @@ TaskQueue::TaskQueue() : done(false), pending(0) { } TaskQueue::~TaskQueue() { - done.store(true, std::memory_order_release); + { + std::lock_guard lock(mtx); + done.store(true, std::memory_order_release); + } + cv.notify_all(); if (workerThread.joinable()) workerThread.join(); Node* node = head.load(std::memory_order_relaxed); @@ -40,11 +44,18 @@ void TaskQueue::enqueue(std::function task) { Node* node = new Node(task); Node* prev = tail.exchange(node, std::memory_order_acq_rel); prev->next.store(node, std::memory_order_release); + { + std::lock_guard lock(mtx); + } + cv.notify_one(); } void TaskQueue::sync(size_t allow_n_pending) { - // Spin until the pending task count drops to the allowed threshold. - while (pending.load(std::memory_order_acquire) > allow_n_pending); + std::unique_lock lock(mtx); + cv.wait(lock, [&] { + return pending.load(std::memory_order_acquire) <= allow_n_pending + || done.load(std::memory_order_acquire); + }); } void TaskQueue::worker() { @@ -58,7 +69,17 @@ void TaskQueue::worker() { delete curr; curr = next; head.store(curr, std::memory_order_release); - pending.fetch_sub(1, std::memory_order_acq_rel); + { + std::lock_guard lock(mtx); + pending.fetch_sub(1, std::memory_order_acq_rel); + } + cv.notify_all(); + } else { + std::unique_lock lock(mtx); + cv.wait(lock, [&] { + return curr->next.load(std::memory_order_acquire) != nullptr + || done.load(std::memory_order_acquire); + }); } } } \ No newline at end of file diff --git a/kt-kernel/cpu_backend/task_queue.h b/kt-kernel/cpu_backend/task_queue.h index 6899c172..e4f21f97 100644 --- a/kt-kernel/cpu_backend/task_queue.h +++ b/kt-kernel/cpu_backend/task_queue.h @@ -40,6 +40,8 @@ class TaskQueue { std::atomic done; std::atomic pending; std::thread workerThread; + std::mutex mtx; + std::condition_variable cv; void worker(); };