You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

217 lines
5.8 KiB
Lua

local Stateful = {
  _VERSION = 'Stateful 1.0.5 (2017-08)',
  _DESCRIPTION = 'Stateful classes for middleclass',
  _URL = 'https://github.com/kikito/stateful.lua',
  -- Class-level methods of the mixin (see Stateful.static:addState).
  static = {},
}
-- requires middleclass >2.0

-- Lifecycle callback names a state may define. They are fired explicitly by
-- the state-stack operations and are excluded from instance method lookup.
local _callbacks = {}
for _, callbackName in ipairs({
  'enteredState', 'exitedState', 'pushedState',
  'poppedState', 'pausedState', 'continuedState',
}) do
  _callbacks[callbackName] = 1
end

-- Default parent for states created without an explicit superState.
local _BaseState = {}
--- Gives `klass` a fresh, empty state registry and re-registers inherited states.
-- Each entry of `superStates` is passed to `klass:addState(name, state)`.
-- @param klass the class receiving the registry
-- @tparam[opt] table superStates name -> state map to re-register (e.g. a parent's states)
local function _addStatesToClass(klass, superStates)
  klass.static.states = {}
  if superStates then
    for inheritedName, parentState in pairs(superStates) do
      klass:addState(inheritedName, parentState)
    end
  end
end
--- Looks `name` up in the instance's state stack, topmost state first.
-- Returns nothing when `name` is a lifecycle callback, when the instance has no
-- state stack yet, or when no state on the stack defines a truthy `name`.
-- @param instance a stateful instance
-- @param name the key being looked up
-- @return the value found in the topmost state defining `name`
local function _getStatefulMethod(instance, name)
  if _callbacks[name] then
    return
  end
  -- rawget: this function runs inside the instance's own __index, and the
  -- stack is only attached after allocation, so a plain index could recurse.
  local stack = rawget(instance, '__stateStack')
  if not stack then
    return
  end
  local depth = #stack
  while depth >= 1 do
    local member = stack[depth][name]
    if member then
      return member
    end
    depth = depth - 1
  end
end
--- Builds an instance `__index` that consults the state stack before `prevIndex`.
-- @param prevIndex the previous `__index`: either a lookup function or a table
-- @treturn function a new `__index(instance, name)` metamethod
local function _getNewInstanceIndex(prevIndex)
  local fallback
  if type(prevIndex) == 'function' then
    fallback = prevIndex
  else
    fallback = function(_, name)
      return prevIndex[name]
    end
  end
  return function(instance, name)
    return _getStatefulMethod(instance, name) or fallback(instance, name)
  end
end
--- Wraps an allocator so every new instance starts with an empty state stack.
-- @tparam function oldAllocateMethod the original `allocate(klass, ...)`
-- @treturn function allocator with the same signature
local function _getNewAllocateMethod(oldAllocateMethod)
  local function allocateWithStack(klass, ...)
    local created = oldAllocateMethod(klass, ...)
    created.__stateStack = {}
    return created
  end
  return allocateWithStack
end
--- Replaces the class's instance `__index` with one that checks the state stack first.
-- @param klass a middleclass class (uses its `__instanceDict`)
local function _modifyInstanceIndex(klass)
  local dict = klass.__instanceDict
  dict.__index = _getNewInstanceIndex(dict.__index)
end
--- Wraps `subclass` so subclasses also get a state registry and stateful lookup.
-- The parent's state registry (`klass.states`) is re-registered on the new
-- subclass, and the subclass's instance `__index` is wrapped.
-- @tparam function prevSubclass the original `subclass(klass, name)`
-- @treturn function replacement with the same signature, returning the subclass
local function _getNewSubclassMethod(prevSubclass)
  local function statefulSubclass(parent, subclassName)
    local child = prevSubclass(parent, subclassName)
    _addStatesToClass(child, parent.states)
    _modifyInstanceIndex(child)
    return child
  end
  return statefulSubclass
end
--- Installs the state-aware `subclass` on the class's static table.
-- @param klass a middleclass class
local function _modifySubclassMethod(klass)
  local static = klass.static
  static.subclass = _getNewSubclassMethod(static.subclass)
end
--- Installs the stack-initialising `allocate` on the class's static table.
-- @param klass a middleclass class
local function _modifyAllocateMethod(klass)
  local static = klass.static
  static.allocate = _getNewAllocateMethod(static.allocate)
end
--- Raises an error unless `val` has Lua type `expected_type`.
-- @param val value to check
-- @tparam string name name of the checked argument, used in the message
-- @tparam string expected_type the required result of `type(val)`
-- @tparam[opt] string type_to_s human-readable description of the accepted
--   type(s) for the message; defaults to `expected_type`
local function _assertType(val, name, expected_type, type_to_s)
  if type(val) ~= expected_type then
    -- Space before the parenthesised type: "Was 5 (number)", not "Was 5(number)".
    error("Expected " .. name .. " to be of type " .. (type_to_s or expected_type) ..
      ". Was " .. tostring(val) .. " (" .. type(val) .. ")")
  end
end
--- Raises an error if `klass` already has a state registered under `stateName`.
-- @param klass the class whose `states` registry is checked
-- @param stateName the name about to be registered
local function _assertInexistingState(klass, stateName)
  local existing = klass.states[stateName]
  if existing == nil then
    return
  end
  error("State " .. tostring(stateName) .. " already exists on " .. tostring(klass))
end
--- Raises an error if a state lookup on `self`'s class came back empty.
-- @param self the stateful instance whose class was searched
-- @param state the lookup result (nil/false when not found)
-- @param stateName the requested name; may be any value, including nil
local function _assertExistingState(self, state, stateName)
  if not state then
    -- tostring: a nil or non-string name must yield this message rather than
    -- a secondary "attempt to concatenate" error.
    error("The state " .. tostring(stateName) .. " was not found in " .. tostring(self.class))
  end
end
--- Calls `state[callbackName](self, ...)` when `state` exists and defines it.
-- The callback's return values are discarded.
-- @param self the stateful instance passed as the callback's first argument
-- @param state a state table, or nil
-- @tparam string callbackName name of the callback to fire
-- @param ... extra arguments forwarded to the callback
local function _invokeCallback(self, state, callbackName, ...)
  local callback = state and state[callbackName]
  if not callback then
    return
  end
  callback(self, ...)
end
--- Returns the state on top of the instance's stack, or nil when it is empty.
-- @param self a stateful instance
local function _getCurrentState(self)
  local stack = self.__stateStack
  return stack[#stack]
end
--- Fetches a registered state from `self`'s class, raising if it does not exist.
-- @param self a stateful instance
-- @param stateName name of the state
-- @treturn table the state
local function _getStateFromClassByName(self, stateName)
  local registry = self.class.static.states
  local found = registry[stateName]
  _assertExistingState(self, found, stateName)
  return found
end
--- Finds the stack position of a state.
-- With a nil name, returns the stack depth (the top index; 0 when empty).
-- Otherwise returns the topmost index holding the named state, or nothing when
-- the state is registered but not on the stack. Unregistered names raise.
-- @param self a stateful instance
-- @param[opt] stateName name of the state to locate
-- @treturn number|nil the index
local function _getStateIndexFromStackByName(self, stateName)
  if stateName == nil then
    return #self.__stateStack
  end
  local wanted = _getStateFromClassByName(self, stateName)
  local stack = self.__stateStack
  local position = #stack
  while position > 0 do
    if stack[position] == wanted then
      return position
    end
    position = position - 1
  end
end
--- Finds the name under which `target` is registered in `self`'s class.
-- @param self a stateful instance
-- @param target a state table
-- @treturn string the state's name; nothing is returned when `target` is not registered
local function _getStateName(self, target)
  local registry = self.class.static.states
  local name, state = next(registry)
  while name ~= nil do
    if state == target then
      return name
    end
    name, state = next(registry, name)
  end
end
--- Mixin hook, presumably invoked by middleclass's `include` (confirm).
-- Gives the class an empty state registry, then wraps its instance `__index`,
-- its static `subclass` and its static `allocate`.
-- @param klass the class Stateful is being included into
function Stateful:included(klass)
  local installers = {
    _addStatesToClass,
    _modifyInstanceIndex,
    _modifySubclassMethod,
    _modifyAllocateMethod,
  }
  for _, install in ipairs(installers) do
    install(klass)
  end
end
--- Registers a new state on the class and returns it.
-- The state is an empty table whose lookups fall back to `superState` (or to
-- an internal empty base state), so states can inherit from other states.
-- @tparam string stateName name of the new state; must not already exist on the class
-- @tparam[opt] table superState state to inherit from
-- @treturn table the newly created state
function Stateful.static:addState(stateName, superState)
  local parent = superState or _BaseState
  _assertType(stateName, 'stateName', 'string')
  _assertInexistingState(self, stateName)
  local state = setmetatable({}, { __index = parent })
  self.static.states[stateName] = state
  return state
end
--- Replaces the whole state stack with a single state, or empties it.
-- Every state on the stack is popped first (see popAllStates). The new state
-- is then installed and receives `enteredState`. The name is validated before
-- anything is popped, so a bad name raises without disturbing the current stack.
-- @tparam[opt] string stateName name of a registered state, or nil to clear the stack
-- @param ... extra arguments forwarded to the pop callbacks and to `enteredState`
function Stateful:gotoState(stateName, ...)
  local newState
  if stateName ~= nil then
    _assertType(stateName, 'stateName', 'string', 'string or nil')
    newState = _getStateFromClassByName(self, stateName)
  end
  self:popAllStates(...)
  if newState == nil then
    self.__stateStack = {}
  else
    self.__stateStack = { newState }
    _invokeCallback(self, newState, 'enteredState', ...)
  end
end
--- Pushes a registered state on top of the state stack.
-- The current top state (if any) receives `pausedState`. The pushed state then
-- receives `pushedState` and `enteredState`. The name is validated before any
-- callback fires, so an invalid name leaves the stack and callbacks untouched.
-- @tparam string stateName name of a state registered with addState
-- @param ... extra arguments forwarded to `pushedState` and `enteredState`
function Stateful:pushState(stateName, ...)
  _assertType(stateName, 'stateName', 'string')
  local newState = _getStateFromClassByName(self, stateName)
  local oldState = _getCurrentState(self)
  -- pausedState is fired without the extra arguments.
  _invokeCallback(self, oldState, 'pausedState')
  table.insert(self.__stateStack, newState)
  _invokeCallback(self, newState, 'pushedState', ...)
  _invokeCallback(self, newState, 'enteredState', ...)
end
--- Pops a state off the state stack.
-- With no name, the top state is popped. With a name, the topmost occurrence
-- of that state is removed wherever it sits. A registered name that is not on
-- the stack is ignored, and an unregistered name raises an error.
-- The popped state receives `poppedState` then `exitedState`. If it was the
-- top of the stack, the state that becomes the new top receives `continuedState`.
-- @tparam[opt] string stateName name of the state to pop; nil pops the top
-- @param ... extra arguments forwarded to every callback fired
function Stateful:popState(stateName, ...)
  local oldStateIndex = _getStateIndexFromStackByName(self, stateName)
  -- nil: the named state is not on the stack; 0: the stack is empty.
  -- Nothing is popped in either case, so no callbacks may fire.
  if not oldStateIndex or oldStateIndex < 1 then
    return
  end
  -- Only removing the top changes which state is current; popping a state from
  -- beneath it must not "continue" a top state that was never paused.
  local wasTop = oldStateIndex == #self.__stateStack
  local oldState = self.__stateStack[oldStateIndex]
  _invokeCallback(self, oldState, 'poppedState', ...)
  _invokeCallback(self, oldState, 'exitedState', ...)
  table.remove(self.__stateStack, oldStateIndex)
  local newState = _getCurrentState(self)
  if wasTop and oldState ~= newState then
    _invokeCallback(self, newState, 'continuedState', ...)
  end
end
--- Pops the top state once per state on the stack when called.
-- The depth is captured up front, so the loop runs exactly that many times
-- even if callbacks change the stack.
-- @param ... extra arguments forwarded to each popState call
function Stateful:popAllStates(...)
  local depth = #self.__stateStack
  local popped = 0
  while popped < depth do
    self:popState(nil, ...)
    popped = popped + 1
  end
end
--- Lists the names of the states on the stack, from top to bottom.
-- @treturn table array of state names, topmost first
function Stateful:getStateStackDebugInfo()
  local stack = self.__stateStack
  local names = {}
  for depth = #stack, 1, -1 do
    table.insert(names, _getStateName(self, stack[depth]))
  end
  return names
end
-- Module export: the Stateful mixin table.
return Stateful