Document & wait queue condition in threadpool destructor.
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
#include <utility>
|
||||
|
||||
namespace MultiThreading {
|
||||
class TaskBase;
|
||||
@@ -29,7 +28,7 @@ private:
|
||||
T* result = nullptr;
|
||||
std::function<T()> callable = nullptr;
|
||||
private:
|
||||
explicit Task(std::function<T()> callable, T* result = nullptr) : TaskBase(), result(result), callable(callable) {}
|
||||
explicit Task(std::function<T()> callable, T* result = nullptr) : TaskBase(), result(result), callable(std::move(callable)) {}
|
||||
public:
|
||||
void Run() final { result ? *result = callable() : callable(); complete = true; }
|
||||
public:
|
||||
@@ -46,7 +45,6 @@ template <>
|
||||
class MultiThreading::Task<void> : public TaskBase {
|
||||
private:
|
||||
std::function<void()> callable = nullptr;
|
||||
private:
|
||||
explicit Task(std::function<void()> callable) : TaskBase(), callable(std::move(callable)) {}
|
||||
public:
|
||||
void Run() final { callable(); complete = true; }
|
||||
@@ -55,4 +53,5 @@ public:
|
||||
/// @param callable The function to run, *usually a lambda or std::bind*
|
||||
/// @note this is shared_ptr so you don't have to delete it.
|
||||
static std::shared_ptr<Task<void>> Create(const std::function<void()>& callable) { return std::shared_ptr<Task<void>>(new Task<void>(callable)); }
|
||||
~Task() = default;
|
||||
};
|
@@ -1,11 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <functional>
|
||||
#include <stdexcept>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include <MultiThreading/Task.h>
|
||||
|
||||
namespace MultiThreading {
|
||||
@@ -23,13 +21,31 @@ private:
|
||||
private:
|
||||
void Runner();
|
||||
public:
|
||||
[[nodiscard]] bool SetTask(const std::function<void()>& task);
|
||||
/// @returns false if the thread is busy.
|
||||
/// @param task The task to run.
|
||||
/// @note Task is passed by value on purpose so that the original shared_ptr to the task is still good.
|
||||
[[nodiscard]] bool SetTask(std::shared_ptr<TaskBase> task);
|
||||
|
||||
/// @returns false if the thread is busy.
|
||||
/// @param task The task to run.
|
||||
/// @note Task is passed by reference because the frame after next is going to copy it anyway.
|
||||
[[nodiscard]] bool SetTask(const std::function<void()>& task);
|
||||
|
||||
/// @returns true if the thread is currently doing work.
|
||||
/// @note SetTask will return false if the thread is busy.
|
||||
[[nodiscard]] bool Busy() const { return busy; }
|
||||
void Join() { if (worker.joinable()) worker.join(); };
|
||||
|
||||
/// Blocks the thread join is called from until execution of this thread exits.
|
||||
/// @note Joining this thread from this thread will raise an exception.
|
||||
void Join();
|
||||
public:
|
||||
/// Create a thread which will initialize and then wait for a task.
|
||||
Thread() { worker = std::thread([this] { this->Runner(); }); }
|
||||
/// Crete a thread which will immediately run a task and then wait for another.
|
||||
explicit Thread(std::shared_ptr<TaskBase> task);
|
||||
/// Crete a thread which will immediately run a task and then wait for another.
|
||||
explicit Thread(const std::function<void()>& task);
|
||||
/// Waits for the current task to finish (if there is one) and destroys the thread.
|
||||
// TODO Avoid spinning.
|
||||
~Thread();
|
||||
};
|
@@ -7,24 +7,46 @@ namespace MultiThreading {
|
||||
class ThreadPool;
|
||||
}
|
||||
|
||||
/// A group of threads to run tasks on.
|
||||
class MultiThreading::ThreadPool {
|
||||
private:
|
||||
std::vector<MultiThreading::Thread*> threads;
|
||||
std::queue<std::shared_ptr<TaskBase>> queue;
|
||||
std::condition_variable queue_condition;
|
||||
std::mutex queue_mutex;
|
||||
private:
|
||||
/// @returns nullptr if the queue is empty.
|
||||
std::shared_ptr<TaskBase> Dequeue();
|
||||
void Runner(const std::shared_ptr<TaskBase>& task);
|
||||
public:
|
||||
/// Set a task to be run on the thread-pool.
|
||||
/// @param task The task to run.
|
||||
void Enqueue(const std::shared_ptr<MultiThreading::TaskBase>& task);
|
||||
|
||||
/// Set a task to be run on the thread-pool.
|
||||
/// @param task The task to run.
|
||||
void Enqueue(const std::function<void()>& task);
|
||||
public:
|
||||
/// @returns The number of threads in the thread pool.
|
||||
[[nodiscard]] unsigned int ThreadCount() const { return threads.size(); }
|
||||
|
||||
/// @returns the number of tasks in the queue
|
||||
/// @note this excludes the tasks currently being run.
|
||||
[[nodiscard]] unsigned int PendingTasks();
|
||||
/// @returns Whether a task you enqueue would have to wait.
|
||||
|
||||
/// @returns Whether a task you enqueue would have to wait for something else to finish.
|
||||
[[nodiscard]] bool Busy();
|
||||
|
||||
/// Uses a condition variable to wait the calling thread until the queue is empty.
|
||||
void WaitQueueEmpty();
|
||||
public:
|
||||
/// Constructs a thread-pool with a given number of threads (or hardware_concurrency as the default)
|
||||
/// @param thread_count The number of threads.
|
||||
/// @note If you do a *lot* of work on the main thread and hardware_concurrency >= 2, You should use hardware_concurrency -1.
|
||||
explicit ThreadPool(unsigned int thread_count = std::thread::hardware_concurrency());
|
||||
|
||||
/// Waits for the threads to empty the queue and destroys the thread pool.
|
||||
/// @note This should be one of the very last things your program does before exiting.
|
||||
// TODO make Enqueue fail during destruction to avoid stalls.
|
||||
~ThreadPool();
|
||||
};
|
6
main.cpp
6
main.cpp
@@ -5,17 +5,13 @@
|
||||
#include <iostream>
|
||||
|
||||
using namespace MultiThreading;
|
||||
|
||||
int32_t some_test_func(int32_t hello) {
|
||||
for (unsigned int i = 0; i < 1000000000; i++) {}
|
||||
std::cout << "task " << hello << " finishes." << std::endl;
|
||||
return rand();
|
||||
}
|
||||
|
||||
void cb(Thread* thread, std::shared_ptr<TaskBase> task) {
|
||||
std::cout << thread->Busy() << std::endl;
|
||||
std::cout << task->Complete() << std::endl;
|
||||
}
|
||||
|
||||
/*
|
||||
int main() {
|
||||
// Each task you create can be run by a thread only once. It's marked as complete after.
|
||||
|
@@ -66,3 +66,8 @@ Thread::Thread(const std::function<void()>& task) {
|
||||
if (!SetTask(task))
|
||||
throw std::runtime_error("Thread constructor failure.");
|
||||
}
|
||||
|
||||
void Thread::Join() {
|
||||
if (worker.joinable())
|
||||
worker.join();
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
using namespace MultiThreading;
|
||||
|
||||
ThreadPool::ThreadPool(unsigned int thread_count) {
|
||||
|
||||
for (unsigned int i = 0; i < thread_count; i++)
|
||||
threads.push_back(new Thread());
|
||||
}
|
||||
@@ -33,6 +34,10 @@ std::shared_ptr<TaskBase> ThreadPool::Dequeue() {
|
||||
|
||||
auto task = queue.front();
|
||||
queue.pop();
|
||||
|
||||
if (queue.empty())
|
||||
queue_condition.notify_all();
|
||||
|
||||
return task;
|
||||
}
|
||||
|
||||
@@ -58,8 +63,7 @@ unsigned int ThreadPool::PendingTasks() {
|
||||
|
||||
ThreadPool::~ThreadPool() {
|
||||
// Wait for all queued tasks to be running.
|
||||
// TODO avoid spin-loop here.
|
||||
while (PendingTasks() != 0) {}
|
||||
WaitQueueEmpty();
|
||||
|
||||
// delete t waits for the thread to exit gracefully.
|
||||
for (auto* t: threads)
|
||||
@@ -76,3 +80,8 @@ bool ThreadPool::Busy() {
|
||||
|
||||
return all_busy;
|
||||
}
|
||||
|
||||
void ThreadPool::WaitQueueEmpty() {
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
queue_condition.wait(lock, [this]{ return queue.empty(); });
|
||||
}
|
||||
|
Reference in New Issue
Block a user