Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,21 @@
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

namespace tp {

class thread_pool {
public:
/// Parameters
struct Params {
thread::Params thread_params{};
size_t size = 1;
};
namespace details {

template <typename ThreadType>
class thread_pool_base {
public:
/// Constructor
thread_pool(const Params &params) noexcept;
thread_pool_base() = default;

/// Destructor
~thread_pool() noexcept;

/// Non movable
thread_pool(thread_pool &&other) = delete;
thread_pool &operator=(thread_pool &&other) = delete;

/// Non copyable
thread_pool(const thread_pool &other) = delete;
thread_pool &operator=(const thread_pool &other) = delete;
virtual ~thread_pool_base() noexcept;

/// Push a task to the task queue
template <typename Callable, typename ... Args>
Expand Down Expand Up @@ -60,16 +50,20 @@ class thread_pool {
/// Current queue size
size_t qsize() const noexcept;

protected:
/// Pool of threads
std::vector<ThreadType> threads_;

/// Worker thread, waits to dequeue tasks from the queue
void worker() noexcept;

private:
using Callback = thread::Callback;
using Task = std::packaged_task<details::function_type<Callback>::type>;

/// Flag to stop all threads
std::atomic<bool> kill_{false};

/// Pool of threads
std::vector<thread> threads_;

/// Lock, protects the queue and the condition variables
mutable std::mutex lock_;

Expand All @@ -82,9 +76,62 @@ class thread_pool {

/// Cancels and joins all threads
void join() noexcept;
};

/// Worker thread, waits to dequeu tasks from the queue
void worker() noexcept;
} // namespace details

template <typename ThreadType>
class thread_pool;

/// std::thread based thread pool
template <>
class thread_pool<std::thread> : public details::thread_pool_base<std::thread> {
public:
/// Parameters
struct Params {
size_t size = 1;
};

/// Constructor
thread_pool(const Params &params) noexcept {
for (size_t t = 0; t < params.size; t++) {
threads_.push_back(std::thread(&thread_pool::worker, this));
}
}

/// Non movable
thread_pool(thread_pool &&other) = delete;
thread_pool &operator=(thread_pool &&other) = delete;

/// Non copyable
thread_pool(const thread_pool &other) = delete;
thread_pool &operator=(const thread_pool &other) = delete;
};

/// tp::thread based thread pool
template <>
class thread_pool<tp::thread> : public details::thread_pool_base<tp::thread> {
public:
/// Parameters
struct Params {
thread::Params thread_params{};
size_t size = 1;
};

/// Constructor
thread_pool(const Params &params) noexcept {
for (size_t t = 0; t < params.size; t++) {
threads_.push_back(tp::thread(params.thread_params, &thread_pool::worker, this));
}
}

/// Non movable
thread_pool(thread_pool &&other) = delete;
thread_pool &operator=(thread_pool &&other) = delete;

/// Non copyable
thread_pool(const thread_pool &other) = delete;
thread_pool &operator=(const thread_pool &other) = delete;
};

} // namespace tp
36 changes: 21 additions & 15 deletions src/thread_pool.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
#include "thread_pool/thread_pool.h"

namespace tp {
namespace tp::details {

thread_pool::thread_pool(const Params &params) noexcept {
for (size_t t = 0; t < params.size; t++) {
threads_.push_back(thread(params.thread_params, &thread_pool::worker, this));
}
}

thread_pool::~thread_pool() noexcept {
template <typename ThreadType>
thread_pool_base<ThreadType>::~thread_pool_base() noexcept {
join();
}

void thread_pool::join(const bool finish_queue) noexcept {
template <typename ThreadType>
void thread_pool_base<ThreadType>::join(const bool finish_queue) noexcept {
// Block until queue is empty
if (finish_queue) {
std::unique_lock lock(lock_);
Expand All @@ -24,24 +20,30 @@ void thread_pool::join(const bool finish_queue) noexcept {
join();
}

size_t thread_pool::size() const noexcept {
template <typename ThreadType>
size_t thread_pool_base<ThreadType>::size() const noexcept {
return threads_.size();
}

size_t thread_pool::qsize() const noexcept {
template <typename ThreadType>
size_t thread_pool_base<ThreadType>::qsize() const noexcept {
std::scoped_lock lock(lock_);
return q_.size();
}

void thread_pool::join() noexcept {
template <typename ThreadType>
void thread_pool_base<ThreadType>::join() noexcept {
kill_ = true;
q_push_notifier_.notify_all();
for (auto &thread : threads_) {
thread.join();
if (thread.joinable()) {
thread.join();
}
}
}

void thread_pool::worker() noexcept {
template <typename ThreadType>
void thread_pool_base<ThreadType>::worker() noexcept {
auto dequeue = [this] {
Task task{};

Expand Down Expand Up @@ -78,4 +80,8 @@ void thread_pool::worker() noexcept {
}
}

} // namespace tp
/// Explicitly instantiate only these 2 acceptable implementations of thread
template class thread_pool_base<std::thread>;
template class thread_pool_base<tp::thread>;

} // namespace tp::details
34 changes: 17 additions & 17 deletions test/test_thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@

#include <array>

using namespace tp;
#define TEMPLATE_TYPES_UNDER_TEST std::thread , tp::thread

TEST_CASE("thread_pool::DefaultConstructible", "[thread_pool]") {
thread_pool tp({});
TEMPLATE_TEST_CASE("thread_pool::DefaultConstructible", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
tp::thread_pool<TestType> tp({});
}

TEST_CASE("thread_pool::100ThreadsConstruction", "[thread_pool]") {
thread_pool tp({.size = 100});
TEMPLATE_TEST_CASE("thread_pool::100ThreadsConstruction", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
tp::thread_pool<TestType> tp({.size = 100});
}

TEST_CASE("thread_pool::1000ThreadsConstruction", "[thread_pool]") {
thread_pool tp({.size = 1000});
TEMPLATE_TEST_CASE("thread_pool::1000ThreadsConstruction", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
tp::thread_pool<TestType> tp({.size = 1000});
}

TEST_CASE("thread_pool::10000ThreadsConstruction", "[thread_pool]") {
thread_pool tp({.size = 10000});
TEMPLATE_TEST_CASE("thread_pool::10000ThreadsConstruction", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
tp::thread_pool<TestType> tp({.size = 10000});
}

TEST_CASE("thread_pool::Work", "[thread_pool]") {
TEMPLATE_TEST_CASE("thread_pool::Work", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
constexpr size_t kNumTasks = 1000;
constexpr size_t kPoolSize = 100;

thread_pool tp({.size = kPoolSize});
tp::thread_pool<TestType> tp({.size = kPoolSize});

SECTION("FreeFunction") {
std::atomic<size_t> count = 0;
Expand Down Expand Up @@ -68,11 +68,11 @@ TEST_CASE("thread_pool::Work", "[thread_pool]") {
}
}

TEST_CASE("thread_pool::OneBurst", "[thread_pool]") {
TEMPLATE_TEST_CASE("thread_pool::OneBurst", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
constexpr size_t kNumTasks = 1000;
constexpr size_t kPoolSize = 100;

thread_pool tp({.size = kPoolSize});
tp::thread_pool<TestType> tp({.size = kPoolSize});

std::atomic<size_t> count = 0;
for (size_t ii = 0; ii < kNumTasks; ii++) {
Expand All @@ -90,11 +90,11 @@ TEST_CASE("thread_pool::OneBurst", "[thread_pool]") {
}
}

TEST_CASE("thread_pool::RepeatedBursts", "[thread_pool]") {
TEMPLATE_TEST_CASE("thread_pool::RepeatedBursts", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
constexpr size_t kNumTasks = 1000;
constexpr size_t kPoolSize = 100;

thread_pool tp({.size = kPoolSize});
tp::thread_pool<TestType> tp({.size = kPoolSize});

std::atomic<size_t> count = 0;
for (size_t round = 0; round < (kNumTasks / kPoolSize); round++) {
Expand All @@ -114,15 +114,15 @@ TEST_CASE("thread_pool::RepeatedBursts", "[thread_pool]") {
}
}

TEST_CASE("thread_pool::Future", "[thread_pool]") {
TEMPLATE_TEST_CASE("thread_pool::Future", "[thread_pool]", TEMPLATE_TYPES_UNDER_TEST) {
std::atomic<bool> signal{false};
auto wait_for_signal = [&signal] {
while (!signal) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
};

thread_pool tp({.size = 1});
tp::thread_pool<TestType> tp({.size = 1});
auto future = tp.push(wait_for_signal);

REQUIRE(std::future_status::timeout == future.wait_for(std::chrono::milliseconds(10)));
Expand Down