You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
3.5 KiB
Lua
120 lines
3.5 KiB
Lua
--- Replace functions of table or upvalue.
|
|
-- Search for the old functions and replace them with new ones.
|
|
local M = {}
|
|
|
|
-- Objects whose functions have been replaced already.
|
|
-- Each objects need to be replaced once.
|
|
local replaced_obj = {}
|
|
|
|
-- Map old functions to new functions.
|
|
-- Used to replace functions finally.
|
|
-- Set to hotfix.updated_func_map.
|
|
local updated_func_map = {}
|
|
|
|
-- Do not update and replace protected objects.
|
|
-- Set to hotfix.protected.
|
|
local protected = {}
|
|
|
|
local replace_functions -- forward declare
|
|
|
|
-- Replace all updated functions in upvalues of function object.
|
|
local function replace_functions_in_upvalues(function_object)
|
|
local obj = function_object
|
|
assert("function" == type(obj))
|
|
assert(not protected[obj])
|
|
assert(obj ~= updated_func_map)
|
|
|
|
for i = 1, math.huge do
|
|
local name, value = debug.getupvalue(obj, i)
|
|
if not name then
|
|
return
|
|
end
|
|
local new_func = updated_func_map[value]
|
|
if new_func then
|
|
assert("function" == type(value))
|
|
debug.setupvalue(obj, i, new_func)
|
|
else
|
|
replace_functions(value)
|
|
end
|
|
end -- for
|
|
assert(false, "Can not reach here!")
|
|
end -- replace_functions_in_upvalues()
|
|
|
|
-- Replace all updated functions in the table.
|
|
local function replace_functions_in_table(table_object)
|
|
local obj = table_object
|
|
assert("table" == type(obj))
|
|
assert(not protected[obj])
|
|
assert(obj ~= updated_func_map)
|
|
|
|
replace_functions(debug.getmetatable(obj))
|
|
local new = {} -- to assign new fields
|
|
for k, v in pairs(obj) do
|
|
local new_k = updated_func_map[k]
|
|
local new_v = updated_func_map[v]
|
|
if new_k then
|
|
obj[k] = nil -- delete field
|
|
new[new_k] = new_v or v
|
|
else
|
|
-- obj[k] = new_v or v
|
|
local ok, err = pcall(rawset, obj, k, new_v or v)
|
|
if not ok then
|
|
-- print("replaced failed:", k, err)
|
|
end
|
|
replace_functions(k)
|
|
end
|
|
if not new_v then
|
|
replace_functions(v)
|
|
end
|
|
end -- for k, v
|
|
for k, v in pairs(new) do
|
|
obj[k] = v
|
|
end
|
|
end -- replace_functions_in_table()
|
|
|
|
-- Replace all updated functions.
|
|
-- Record all replaced objects in replaced_obj.
|
|
function replace_functions(obj)
|
|
if protected[obj] then
|
|
return
|
|
end
|
|
local obj_type = type(obj)
|
|
if "function" ~= obj_type and "table" ~= obj_type then
|
|
return
|
|
end
|
|
if replaced_obj[obj] then
|
|
return
|
|
end
|
|
replaced_obj[obj] = true
|
|
assert(obj ~= updated_func_map)
|
|
|
|
if "function" == obj_type then
|
|
replace_functions_in_upvalues(obj)
|
|
else -- table
|
|
replace_functions_in_table(obj)
|
|
end
|
|
end -- replace_functions(obj)
|
|
|
|
--- Replace all old functions with new ones.
|
|
-- Replace in new_obj, _G and debug.getregistry().
|
|
-- a_protected is a list of protected object.
|
|
-- an_updated_func_map is a map from old function to new function.
|
|
-- new_obj is the newly loaded module.
|
|
function M.replace_all(a_protected, an_updated_func_map, new_obj)
|
|
protected = a_protected
|
|
updated_func_map = an_updated_func_map
|
|
assert(type(protected) == "table")
|
|
assert(type(updated_func_map) == "table")
|
|
if nil == next(updated_func_map) then
|
|
return
|
|
end
|
|
|
|
replaced_obj = {}
|
|
replace_functions(new_obj) -- new_obj may be not in _G
|
|
replace_functions(_G)
|
|
replace_functions(debug.getregistry())
|
|
replaced_obj = {}
|
|
end -- M.replace_all()
|
|
|
|
return M
|