Small fixes. Make the demo nicer.

This commit is contained in:
2025-04-09 21:10:28 -04:00
parent 3b9ca2e144
commit d20bda0d4e
5 changed files with 43 additions and 43 deletions

View File

@@ -14,7 +14,6 @@ namespace MultiThreading {
class MultiThreading::Thread {
private:
std::function<void(Thread* thread, std::shared_ptr<TaskBase> task)> on_task_complete_callback = nullptr;
std::shared_ptr<TaskBase> current_task = nullptr;
std::atomic<bool> stop = false;
std::condition_variable cv;
@@ -28,7 +27,6 @@ public:
[[nodiscard]] bool SetTask(std::shared_ptr<TaskBase> task);
[[nodiscard]] bool Busy() const { return busy; }
void Join() { if (worker.joinable()) worker.join(); };
void SetTaskCompletionCallback(std::function<void(Thread*, std::shared_ptr<TaskBase>)> callback) { on_task_complete_callback = std::move(callback); };
public:
Thread() { worker = std::thread([this] { this->Runner(); }); }
explicit Thread(std::shared_ptr<TaskBase> task);

View File

@@ -21,9 +21,10 @@ public:
void Enqueue(const std::function<void()>& task);
public:
[[nodiscard]] unsigned int ThreadCount() const { return threads.size(); }
[[nodiscard]] unsigned int QueueSize();
[[nodiscard]] unsigned int PendingTasks();
/// @returns Whether a task you enqueue would have to wait.
[[nodiscard]] bool Busy();
public:
ThreadPool();
explicit ThreadPool(unsigned int thread_count);
explicit ThreadPool(unsigned int thread_count = std::thread::hardware_concurrency());
~ThreadPool();
};

View File

@@ -6,9 +6,9 @@
using namespace MultiThreading;
int32_t some_test_func(int32_t hello) {
for (unsigned int i = 0; i < 50; i++)
std::cout << i << std::endl;
return hello;
for (unsigned int i = 0; i < 1000000000; i++) {}
std::cout << "task " << hello << " finishes." << std::endl;
return rand();
}
void cb(Thread* thread, std::shared_ptr<TaskBase> task) {
@@ -16,6 +16,7 @@ void cb(Thread* thread, std::shared_ptr<TaskBase> task) {
std::cout << task->Complete() << std::endl;
}
/*
int main() {
// Each task you create can be run by a thread only once. It's marked as complete after.
// If you're running a lambda or std::function directly on the thread, It can be used multiple times.
@@ -42,26 +43,21 @@ int main() {
//std::cout << a << std::endl;
}
*/
/*
int main() {
ThreadPool thread_pool(1);
srand(time(nullptr));
auto some_task_1 = Task<void>::Create(([] { return some_test_func(1); }));
auto some_task_2 = Task<void>::Create(([] { return some_test_func(1); }));
auto some_task_3 = Task<void>::Create(([] { return some_test_func(1); }));
auto some_task_4 = Task<void>::Create(([] { return some_test_func(1); }));
int32_t task_completion_count = 0;
auto* thread_pool = new ThreadPool();
thread_pool.Enqueue(some_task_1);
thread_pool.Enqueue(some_task_2);
thread_pool.Enqueue(some_task_3);
thread_pool.Enqueue(some_task_4);
for (unsigned int i = 0; i < 128; i++) {
auto some_task = Task<int32_t>::Create(([i] { return some_test_func(i + 1); }), &task_completion_count);
thread_pool->Enqueue(some_task);
}
while (!some_task_4->Complete()) {}
std::cout << thread_pool.ThreadCount() << std::endl;
//delete thread_pool;
}
*/
delete thread_pool;
std::cout << "The returned random value was: " << task_completion_count << std::endl;
}

View File

@@ -46,9 +46,6 @@ void Thread::Runner() {
current_task->Run();
if (on_task_complete_callback)
on_task_complete_callback(this, current_task);
lock.lock();
current_task = nullptr;

View File

@@ -7,22 +7,17 @@ ThreadPool::ThreadPool(unsigned int thread_count) {
threads.push_back(new Thread());
}
ThreadPool::ThreadPool() {
for (unsigned int i = 0; i < std::thread::hardware_concurrency(); i++)
threads.push_back(new Thread());
}
void ThreadPool::Enqueue(const std::shared_ptr<MultiThreading::TaskBase>& task) {
std::lock_guard<std::mutex> lock(queue_mutex);
// Assign it immediately if there's no wait and a thread open.
if (queue.empty()) {
for (auto *t: threads) {
for (auto* t: threads) {
if (t->Busy())
continue;
if (!t->SetTask( [this, task] (){ Runner(task); } ))
throw std::runtime_error("There was an error while setting up the task to run on the thread.");
throw std::runtime_error("There was a collision while putting the task to the thread.");
return;
}
}
@@ -46,25 +41,38 @@ void ThreadPool::Enqueue(const std::function<void()>& task) {
}
void ThreadPool::Runner(const std::shared_ptr<TaskBase>& task) {
if (!task->Complete())
task->Run();
auto running_task = task;
auto next_task = Dequeue();
if (!next_task)
return;
Runner(next_task);
while (running_task) {
if (!running_task->Complete())
running_task->Run();
running_task = Dequeue();
}
}
unsigned int ThreadPool::QueueSize() {
unsigned int ThreadPool::PendingTasks() {
std::lock_guard<std::mutex> lock(queue_mutex);
return queue.size();
}
ThreadPool::~ThreadPool() {
// Wait for all tasks to be running.
while (QueueSize() != 0) {}
// Wait for all queued tasks to be running.
// TODO avoid spin-loop here.
while (PendingTasks() != 0) {}
// delete t waits for the thread to exit gracefully.
for (auto* t: threads)
delete t;
}
bool ThreadPool::Busy() {
if (PendingTasks() != 0)
return true;
bool all_busy = true;
for (auto* t : threads)
if (!t->Busy()) { all_busy = false; break; }
return all_busy;
}