You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
84 lines
2.5 KiB
Lua
84 lines
2.5 KiB
Lua
--- Module for creating a thread pool, based on Lua Lanes.
|
|
-- @module lqc.threading.thread_pool
|
|
-- @alias ThreadPool
|
|
local MsgProcessor = require 'lqc.threading.msg_processor'
|
|
local map = require 'lqc.helpers.map'
|
|
local lanes = require('lanes').configure {
|
|
with_timers = false,
|
|
}
|
|
|
|
--- Tells whether a value is a whole number strictly greater than zero.
-- @param x value to inspect (any type)
-- @return true when x is a number with no fractional part and x > 0;
--   false for every other value (non-numbers, 0, negatives, fractions,
--   NaN and infinities).
local function is_positive_integer(x)
  -- Guard clause: reject non-numbers and anything not above zero up front.
  if type(x) ~= 'number' or x <= 0 then
    return false
  end
  -- A zero remainder modulo 1 means there is no fractional part.
  return x % 1 == 0
end
|
|
|
|
--- Validates the arguments passed to ThreadPool.new.
-- @param num_threads requested number of worker threads; must be an
--   integer > 0
-- @return nil; raises an error if invalid args are passed in.
local function check_threadpool_args(num_threads)
  if not is_positive_integer(num_threads) then
    -- Level 3 attributes the error to the code that called ThreadPool.new
    -- (level 1 = this validator, level 2 = ThreadPool.new), so the reported
    -- position points at the offending argument instead of at this helper.
    error('num_threads should be an integer > 0', 3)
  end
end
|
|
|
|
--- Starts a new Lanes lane that runs the given function.
-- The lane generator is built with the '*' library set (all standard
-- libraries, per the Lanes documentation) and is invoked immediately with
-- no arguments, which launches the lane.
-- @param func function the lane executes once it starts
-- @return the handle of the started lane
local function make_thread(func)
  local generator = lanes.gen('*', func)
  return generator()
end
|
|
|
|
--- Class table holding the ThreadPool constructor and methods.
local ThreadPool = {}

--- Metatable shared by all ThreadPool instances: field lookups that miss
-- on an instance fall back to the ThreadPool class table, which is how
-- instances reach the methods defined on it.
local ThreadPool_mt = {}
ThreadPool_mt.__index = ThreadPool
|
|
|
|
--- Builds a thread pool and starts its worker threads.
-- Each worker lane runs the function returned by MsgProcessor.new, and all
-- workers are handed the same linda, which the pool also keeps so it can
-- send tasks and receive results.
-- @param num_threads number of worker threads to start (integer > 0)
-- @return a ThreadPool instance holding the started threads
function ThreadPool.new(num_threads)
  check_threadpool_args(num_threads)

  local shared_linda = lanes.linda()

  local workers = {}
  for _ = 1, num_threads do
    workers[#workers + 1] = make_thread(MsgProcessor.new(shared_linda))
  end

  return setmetatable({
    threads = workers,
    linda = shared_linda,
    numjobs = 0, -- jobs scheduled so far (incremented by schedule)
  }, ThreadPool_mt)
end
|
|
|
|
--- Queues a task for one of the pool's worker threads.
-- The task is sent on the pool's linda under MsgProcessor.TASK_TAG, and the
-- pool's job counter is incremented by one.
-- @param task the value to hand to a worker (normally a function to run;
--   join also uses this to send MsgProcessor.STOP_VALUE)
function ThreadPool:schedule(task)
  local job_count = self.numjobs + 1
  self.numjobs = job_count
  self.linda:send(nil, MsgProcessor.TASK_TAG, task)
end
|
|
|
|
--- Stops all threads in the thread pool and collects the task results.
-- One MsgProcessor.STOP_VALUE message is queued per worker thread, then the
-- call blocks until every worker has finished. Stop messages are not
-- counted as result-producing jobs, and results equal to
-- MsgProcessor.VOID_RESULT are left out of the returned table.
-- @return a table containing all results (in no specific order)
-- @raise an error if a worker lane terminated with an error, instead of
--   risking a wait on results that lane may never have sent
function ThreadPool:join()
  local threads = self.threads
  local num_threads = #threads

  -- Results to wait for: one per job scheduled so far. Captured before the
  -- stop messages below increment the job counter.
  local pending = self.numjobs

  for _ = 1, num_threads do
    self:schedule(MsgProcessor.STOP_VALUE)
  end

  -- Join every worker before reporting a failure so no lane is left running.
  local failure
  for i = 1, num_threads do
    local thread = threads[i]
    local _, err = thread:join()
    if failure == nil and thread.status == 'error' then
      failure = ('thread %d of the thread pool failed: %s'):format(i, tostring(err))
    end
  end
  if failure then
    -- A dead lane may not have posted the result for the task it was
    -- running; the receive loop below (no timeout) could then block forever.
    error(failure, 2)
  end

  local results = {}
  for _ = 1, pending do
    local _, result = self.linda:receive(nil, MsgProcessor.RESULT_TAG)
    if result ~= MsgProcessor.VOID_RESULT then
      results[#results + 1] = result
    end
  end

  -- All scheduled jobs (tasks and stop messages) are now accounted for;
  -- reset the counter so a repeated join() does not wait for results that
  -- were already consumed.
  self.numjobs = 0
  return results
end
|
|
|
|
-- The module's value is the ThreadPool class table (constructor: ThreadPool.new).
return ThreadPool
|
|
|