🔧 build: 增加 luassert 测试 库

develop
cloudfreexiao 5 years ago
parent 4e3704ac97
commit 803f7326c7

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,473 @@
--[[
Minimal test framework for Lua.
lester - v0.1.2 - 15/Feb/2021
Eduardo Bart - edub4rt@gmail.com
https://github.com/edubart/lester
Minimal Lua test framework.
See end of file for LICENSE.
]] --[[--
Lester is a minimal unit testing framework for Lua with a focus on being simple to use.
## Features
* Minimal, just one file.
* Self contained, no external dependencies.
* Simple and hackable when needed.
* Use `describe` and `it` blocks to describe tests.
* Supports `before` and `after` handlers.
* Colored output.
* Configurable via the script or with environment variables.
* Quiet mode, to use in live development.
* Optionally filter tests by name.
* Show traceback on errors.
* Show time to complete tests.
* Works with Lua 5.1+.
* Efficient.
## Usage
Copy `lester.lua` file to a project and require it,
which returns a table that includes all of the functionality:
```lua
local lester = require 'lester'
local describe, it, expect = lester.describe, lester.it, lester.expect
-- Customize lester configuration.
lester.show_traceback = false
describe('my project', function()
lester.before(function()
-- This function is run before every test.
end)
describe('module1', function() -- Describe blocks can be nested.
it('feature1', function()
expect.equal('something', 'something') -- Pass.
end)
it('feature2', function()
expect.truthy(false) -- Fail.
end)
end)
end)
lester.report() -- Print overall statistic of the tests run.
lester.exit() -- Exit with success if all tests passed.
```
## Customizing output with environment variables
To customize the output of lester externally,
you can set the following environment variables before running a test suite:
* `LESTER_QUIET="true"`, omit print of passed tests.
* `LESTER_COLORED="false"`, disable colored output.
* `LESTER_SHOW_TRACEBACK="false"`, disable traceback on test failures.
* `LESTER_SHOW_ERROR="false"`, omit print of error description of failed tests.
* `LESTER_STOP_ON_FAIL="true"`, stop on first test failure.
* `LESTER_UTF8TERM="false"`, disable printing of UTF-8 characters.
* `LESTER_FILTER="some text"`, filter the tests that should be run.
Note that these configurations can be changed via script too, check the documentation.
]] -- Returns whether the terminal supports UTF-8 characters.
local function is_utf8term()
local lang = os.getenv('LANG')
return (lang and lang:lower():match('utf%-8$')) and true or false
end
-- Returns whether a system environment variable is "true".
local function getboolenv(varname, default)
local val = os.getenv(varname)
if val == 'true' then
return true
elseif val == 'false' then
return false
end
return default
end
-- The lester module.
local lester = {
--- Weather lines of passed tests should not be printed. False by default.
quiet = getboolenv('LESTER_QUIET', false),
--- Weather the output should be colorized. True by default.
colored = getboolenv('LESTER_COLORED', true),
--- Weather a traceback must be shown on test failures. True by default.
show_traceback = getboolenv('LESTER_SHOW_TRACEBACK', true),
--- Weather the error description of a test failure should be shown. True by default.
show_error = getboolenv('LESTER_SHOW_ERROR', true),
--- Weather test suite should exit on first test failure. False by default.
stop_on_fail = getboolenv('LESTER_STOP_ON_FAIL', false),
--- Weather we can print UTF-8 characters to the terminal. True by default when supported.
utf8term = getboolenv('LESTER_UTF8TERM', is_utf8term()),
--- A string with a lua pattern to filter tests. Nil by default.
filter = os.getenv('LESTER_FILTER'),
--- Function to retrieve time in seconds with milliseconds precision, `os.clock` by default.
seconds = os.clock,
}
-- Variables used internally for the lester state.
local lester_start = nil
local last_succeeded = false
local level = 0
local successes = 0
local total_successes = 0
local failures = 0
local total_failures = 0
local start = 0
local befores = {}
local afters = {}
local names = {}
-- Color codes.
local color_codes = {
reset = string.char(27) .. '[0m',
bright = string.char(27) .. '[1m',
red = string.char(27) .. '[31m',
green = string.char(27) .. '[32m',
blue = string.char(27) .. '[34m',
magenta = string.char(27) .. '[35m',
}
-- Colors table, returning proper color code if colored mode is enabled.
local colors = setmetatable({}, {
__index = function(_, key)
return lester.colored and color_codes[key] or ''
end,
})
--- Table of terminal colors codes, can be customized.
lester.colors = colors
--- Describe a block of tests, which consists in a set of tests.
-- Describes can be nested.
-- @param name A string used to describe the block.
-- @param func A function containing all the tests or other describes.
function lester.describe(name, func)
if level == 0 then -- Get start time for top level describe blocks.
start = lester.seconds()
if not lester_start then
lester_start = start
end
end
-- Setup describe block variables.
failures = 0
successes = 0
level = level + 1
names[level] = name
-- Run the describe block.
func()
-- Cleanup describe block.
afters[level] = nil
befores[level] = nil
names[level] = nil
level = level - 1
-- Pretty print statistics for top level describe block.
if level == 0 and not lester.quiet and (successes > 0 or failures > 0) then
local io_write = io.write
local colors_reset, colors_green = colors.reset, colors.green
io_write(failures == 0 and colors_green or colors.red, '[====] ', colors.magenta, name, colors_reset, ' | ',
colors_green, successes, colors_reset, ' successes / ')
if failures > 0 then
io_write(colors.red, failures, colors_reset, ' failures / ')
end
io_write(colors.bright, string.format('%.6f', lester.seconds() - start), colors_reset, ' seconds\n')
end
end
-- Error handler used to get traceback for errors.
local function xpcall_error_handler(err)
return debug.traceback(tostring(err), 2)
end
-- Pretty print the line on the test file where an error happened.
local function show_error_line(err)
local info = debug.getinfo(3)
local io_write = io.write
local colors_reset = colors.reset
local short_src, currentline = info.short_src, info.currentline
io_write(' (', colors.blue, short_src, colors_reset, ':', colors.bright, currentline, colors_reset)
if err and lester.show_traceback then
local fnsrc = short_src .. ':' .. currentline
for cap1, cap2 in err:gmatch('\t[^\n:]+:(%d+): in function <([^>]+)>\n') do
if cap2 == fnsrc then
io_write('/', colors.bright, cap1, colors_reset)
break
end
end
end
io_write(')')
end
-- Pretty print the test name, with breadcrumb for the describe blocks.
local function show_test_name(name)
local io_write = io.write
local colors_reset = colors.reset
for _, descname in ipairs(names) do
io_write(colors.magenta, descname, colors_reset, ' | ')
end
io_write(colors.bright, name, colors_reset)
end
--- Declare a test, which consists of a set of assertions.
-- @param name A name for the test.
-- @param func The function containing all assertions.
function lester.it(name, func)
-- Skip the test if it does not match the filter.
if lester.filter then
local fullname = table.concat(names, ' | ') .. ' | ' .. name
if not fullname:match(lester.filter) then
return
end
end
-- Execute before handlers.
for _, levelbefores in ipairs(befores) do
for _, beforefn in ipairs(levelbefores) do
beforefn(name)
end
end
-- Run the test, capturing errors if any.
local success, err
if lester.show_traceback then
success, err = xpcall(func, xpcall_error_handler)
else
success, err = pcall(func)
if not success and err then
err = tostring(err)
end
end
-- Count successes and failures.
if success then
successes = successes + 1
total_successes = total_successes + 1
else
failures = failures + 1
total_failures = total_failures + 1
end
local io_write = io.write
local colors_reset = colors.reset
-- Print the test run.
if not lester.quiet then -- Show test status and complete test name.
if success then
io_write(colors.green, '[PASS] ', colors_reset)
else
io_write(colors.red, '[FAIL] ', colors_reset)
end
show_test_name(name)
if not success then
show_error_line(err)
end
io_write('\n')
else
if success then -- Show just a character hinting that the test succeeded.
local o = (lester.utf8term and lester.colored) and string.char(226, 151, 143) or 'o'
io_write(colors.green, o, colors_reset)
else -- Show complete test name on failure.
io_write(last_succeeded and '\n' or '', colors.red, '[FAIL] ', colors_reset)
show_test_name(name)
show_error_line(err)
io_write('\n')
end
end
-- Print error message, colorizing its output if possible.
if err and lester.show_error then
if lester.colored then
local errfile, errline, errmsg, rest = err:match('^([^:\n]+):(%d+): ([^\n]+)(.*)')
if errfile and errline and errmsg and rest then
io_write(colors.blue, errfile, colors_reset, ':', colors.bright, errline, colors_reset, ': ')
if errmsg:match('^%w([^:]*)$') then
io_write(colors.red, errmsg, colors_reset)
else
io_write(errmsg)
end
err = rest
end
end
io_write(err, '\n\n')
end
io.flush()
-- Stop on failure.
if not success and lester.stop_on_fail then
if lester.quiet then
io_write('\n')
io.flush()
end
lester.exit()
end
-- Execute after handlers.
for _, levelafters in ipairs(afters) do
for _, afterfn in ipairs(levelafters) do
afterfn(name)
end
end
last_succeeded = success
end
--- Set a function that is called before every test inside a describe block.
-- A single string containing the name of the test about to be run will be passed to `func`.
function lester.before(func)
local levelbefores = befores[level]
if not levelbefores then
levelbefores = {}
befores[level] = levelbefores
end
levelbefores[#levelbefores + 1] = func
end
--- Set a function that is called after every test inside a describe block.
-- A single string containing the name of the test that was finished will be passed to `func`.
-- The function is executed independently if the test passed or failed.
function lester.after(func)
local levelafters = afters[level]
if not levelafters then
levelafters = {}
afters[level] = levelafters
end
levelafters[#levelafters + 1] = func
end
--- Pretty print statistics of all test runs.
-- With total success, total failures and run time in seconds.
function lester.report()
local now = lester.seconds()
local colors_reset = colors.reset
io.write(lester.quiet and '\n' or '', colors.green, total_successes, colors_reset, ' successes / ', colors.red,
total_failures, colors_reset, ' failures / ', colors.bright, string.format('%.6f', now - (lester_start or now)),
colors_reset, ' seconds\n')
io.flush()
return total_failures == 0
end
--- Exit the application with success code if all tests passed, or failure code otherwise.
function lester.exit()
os.exit(total_failures == 0)
end
local expect = {}
--- Expect module, containing utility function for doing assertions inside a test.
lester.expect = expect
--- Check if a function fails with an error.
-- If `expected` is nil then any error is accepted.
-- If `expected` is a string then we check if the error contains that string.
-- If `expected` is anything else then we check if both are equal.
function expect.fail(func, expected)
local ok, err = pcall(func)
if ok then
error('expected function to fail', 2)
elseif expected ~= nil then
local found = expected == err
if not found and type(expected) == 'string' then
found = string.find(tostring(err), expected, 1, true)
end
if not found then
error('expected function to fail\nexpected:\n' .. tostring(expected) .. '\ngot:\n' .. tostring(err), 2)
end
end
end
--- Check if a function does not fail with a error.
function expect.not_fail(func)
local ok, err = pcall(func)
if not ok then
error('expected function to not fail\ngot error:\n' .. tostring(err), 2)
end
end
--- Check if a value is not `nil`.
function expect.exist(v)
if v == nil then
error('expected value to exist\ngot:\n' .. tostring(v), 2)
end
end
--- Check if a value is `nil`.
function expect.not_exist(v)
if v ~= nil then
error('expected value to not exist\ngot:\n' .. tostring(v), 2)
end
end
--- Check if an expression is evaluates to `true`.
function expect.truthy(v)
if not v then
error('expected expression to be true\ngot:\n' .. tostring(v), 2)
end
end
--- Check if an expression is evaluates to `false`.
function expect.falsy(v)
if v then
error('expected expression to be false\ngot:\n' .. tostring(v), 2)
end
end
--- Compare if two values are equal, considering nested tables.
local function strict_eq(t1, t2)
if rawequal(t1, t2) then
return true
end
if type(t1) ~= type(t2) then
return false
end
if type(t1) ~= 'table' then
return t1 == t2
end
if getmetatable(t1) ~= getmetatable(t2) then
return false
end
for k, v1 in pairs(t1) do
if not strict_eq(v1, t2[k]) then
return false
end
end
for k, v2 in pairs(t2) do
if not strict_eq(v2, t1[k]) then
return false
end
end
return true
end
--- Check if two values are equal.
function expect.equal(v1, v2)
if not strict_eq(v1, v2) then
error('expected values to be equal\nfirst value:\n' .. tostring(v1) .. '\nsecond value:\n' .. tostring(v2), 2)
end
end
--- Check if two values are not equal.
function expect.not_equal(v1, v2)
if strict_eq(v1, v2) then
error('expected values to be not equal\nfirst value:\n' .. tostring(v1) .. '\nsecond value:\n' .. tostring(v2),
2)
end
end
return lester
--[[
The MIT License (MIT)
Copyright (c) 2021 Eduardo Bart (https://github.com/edubart)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
]]

@ -0,0 +1,70 @@
local assert = require('luassert.assert')
local say = require('luassert.say')
-- Example usage:
-- local arr = { "one", "two", "three" }
--
-- assert.array(arr).has.no.holes() -- checks the array to not contain holes --> passes
-- assert.array(arr).has.no.holes(4) -- sets explicit length to 4 --> fails
--
-- local first_hole = assert.array(arr).has.holes(4) -- check array of size 4 to contain holes --> passes
-- assert.equal(4, first_hole) -- passes, as the index of the first hole is returned
-- Unique key to store the object we operate on in the state object
-- key must be unique, to make sure we do not have name collissions in the shared state object
local ARRAY_STATE_KEY = "__array_state"
-- The modifier, to store the object in our state
local function array(state, args, level)
assert(args.n > 0, "No array provided to the array-modifier")
assert(rawget(state, ARRAY_STATE_KEY) == nil, "Array already set")
rawset(state, ARRAY_STATE_KEY, args[1])
return state
end
-- The actual assertion that operates on our object, stored via the modifier
local function holes(state, args, level)
local length = args[1]
local arr = rawget(state, ARRAY_STATE_KEY) -- retrieve previously set object
-- only check against nil, metatable types are allowed
assert(arr ~= nil, "No array set, please use the array modifier to set the array to validate")
if length == nil then
length = 0
for i in pairs(arr) do
if type(i) == "number" and
i > length and
math.floor(i) == i then
length = i
end
end
end
assert(type(length) == "number", "expected array length to be of type 'number', got: "..tostring(length))
-- let's do the actual assertion
local missing
for i = 1, length do
if arr[i] == nil then
missing = i
break
end
end
-- format arguments for output strings;
args[1] = missing
args.n = missing and 1 or 0
return missing ~= nil, { missing } -- assert result + first missing index as return value
end
-- Register the proper assertion messages
say:set("assertion.array_holes.positive", [[
Expected array to have holes, but none was found.
]])
say:set("assertion.array_holes.negative", [[
Expected array to not have holes, hole found at position: %s
]])
-- Register the assertion, and the modifier
assert:register("assertion", "holes", holes,
"assertion.array_holes.positive",
"assertion.array_holes.negative")
assert:register("modifier", "array", array)

@ -0,0 +1,180 @@
local s = require 'luassert.say'
local astate = require 'luassert.state'
local util = require 'luassert.util'
local unpack = util.unpack
local obj -- the returned module table
local level_mt = {}
-- list of namespaces
local namespace = require 'luassert.namespaces'
local function geterror(assertion_message, failure_message, args)
if util.hastostring(failure_message) then
failure_message = tostring(failure_message)
elseif failure_message ~= nil then
failure_message = astate.format_argument(failure_message)
end
local message = s(assertion_message, obj:format(args))
if message and failure_message then
message = failure_message .. "\n" .. message
end
return message or failure_message
end
local __state_meta = {
__call = function(self, ...)
local keys = util.extract_keys("assertion", self.tokens)
local assertion
for _, key in ipairs(keys) do
assertion = namespace.assertion[key] or assertion
end
if assertion then
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self)
end
end
local arguments = util.make_arglist(...)
local val, retargs = assertion.callback(self, arguments, util.errorlevel())
if not val == self.mod then
local message = assertion.positive_message
if not self.mod then
message = assertion.negative_message
end
local err = geterror(message, rawget(self,"failure_message"), arguments)
error(err or "assertion failed!", util.errorlevel())
end
if retargs then
return unpack(retargs)
end
return ...
else
local arguments = util.make_arglist(...)
self.tokens = {}
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self, arguments, util.errorlevel())
end
end
end
return self
end,
__index = function(self, key)
for token in key:lower():gmatch('[^_]+') do
table.insert(self.tokens, token)
end
return self
end
}
obj = {
state = function() return setmetatable({mod=true, tokens={}}, __state_meta) end,
-- registers a function in namespace
register = function(self, nspace, name, callback, positive_message, negative_message)
local lowername = name:lower()
if not namespace[nspace] then
namespace[nspace] = {}
end
namespace[nspace][lowername] = {
callback = callback,
name = lowername,
positive_message=positive_message,
negative_message=negative_message
}
end,
-- unregisters a function in a namespace
unregister = function(self, nspace, name)
local lowername = name:lower()
if not namespace[nspace] then
namespace[nspace] = {}
end
namespace[nspace][lowername] = nil
end,
-- registers a formatter
-- a formatter takes a single argument, and converts it to a string, or returns nil if it cannot format the argument
add_formatter = function(self, callback)
astate.add_formatter(callback)
end,
-- unregisters a formatter
remove_formatter = function(self, fmtr)
astate.remove_formatter(fmtr)
end,
format = function(self, args)
-- args.n specifies the number of arguments in case of 'trailing nil' arguments which get lost
local nofmt = args.nofmt or {} -- arguments in this list should not be formatted
local fmtargs = args.fmtargs or {} -- additional arguments to be passed to formatter
for i = 1, (args.n or #args) do -- cannot use pairs because table might have nils
if not nofmt[i] then
local val = args[i]
local valfmt = astate.format_argument(val, nil, fmtargs[i])
if valfmt == nil then valfmt = tostring(val) end -- no formatter found
args[i] = valfmt
end
end
return args
end,
set_parameter = function(self, name, value)
astate.set_parameter(name, value)
end,
get_parameter = function(self, name)
return astate.get_parameter(name)
end,
add_spy = function(self, spy)
astate.add_spy(spy)
end,
snapshot = function(self)
return astate.snapshot()
end,
level = function(self, level)
return setmetatable({
level = level
}, level_mt)
end,
-- returns the level if a level-value, otherwise nil
get_level = function(self, level)
if getmetatable(level) ~= level_mt then
return nil -- not a valid error-level
end
return level.level
end,
}
local __meta = {
__call = function(self, bool, message, level, ...)
if not bool then
local err_level = (self:get_level(level) or 1) + 1
error(message or "assertion failed!", err_level)
end
return bool , message , level , ...
end,
__index = function(self, key)
return rawget(self, key) or self.state()[key]
end,
}
return setmetatable(obj, __meta)

@ -0,0 +1,328 @@
-- module will not return anything, only register assertions with the main assert engine
-- assertions take 2 parameters;
-- 1) state
-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils
-- 3) level The level of the error position relative to the called function
-- returns; boolean; whether assertion passed
local assert = require('luassert.assert')
local astate = require ('luassert.state')
local util = require ('luassert.util')
local s = require('luassert.say')
local function format(val)
return astate.format_argument(val) or tostring(val)
end
local function set_failure_message(state, message)
if message ~= nil then
state.failure_message = message
end
end
local function unique(state, arguments, level)
local list = arguments[1]
local deep
local argcnt = arguments.n
if type(arguments[2]) == "boolean" or (arguments[2] == nil and argcnt > 2) then
deep = arguments[2]
set_failure_message(state, arguments[3])
else
if type(arguments[3]) == "boolean" then
deep = arguments[3]
end
set_failure_message(state, arguments[2])
end
for k,v in pairs(list) do
for k2, v2 in pairs(list) do
if k ~= k2 then
if deep and util.deepcompare(v, v2, true) then
return false
else
if v == v2 then
return false
end
end
end
end
end
return true
end
local function near(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 2, s("assertion.internal.argtolittle", { "near", 3, tostring(argcnt) }), level)
local expected = tonumber(arguments[1])
local actual = tonumber(arguments[2])
local tolerance = tonumber(arguments[3])
local numbertype = "number or object convertible to a number"
assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level)
assert(actual, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level)
assert(tolerance, s("assertion.internal.badargtype", { 3, "near", numbertype, format(arguments[3]) }), level)
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
arguments[3] = tolerance
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[3] = true
set_failure_message(state, arguments[4])
return (actual >= expected - tolerance and actual <= expected + tolerance)
end
local function matches(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "matches", 2, tostring(argcnt) }), level)
local pattern = arguments[1]
local actual = nil
if util.hastostring(arguments[2]) or type(arguments[2]) == "number" then
actual = tostring(arguments[2])
end
local err_message
local init_arg_num = 3
for i=3,argcnt,1 do
if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then
if i == 3 then init_arg_num = init_arg_num + 1 end
err_message = util.tremove(arguments, i)
break
end
end
local init = arguments[3]
local plain = arguments[4]
local stringtype = "string or object convertible to a string"
assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
assert(actual, s("assertion.internal.badargtype", { 2, "matches", stringtype, format(arguments[2]) }), level)
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level)
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
set_failure_message(state, err_message)
local retargs
local ok
if plain then
ok = (actual:find(pattern, init, plain) ~= nil)
retargs = ok and { pattern } or {}
else
retargs = { actual:match(pattern, init) }
ok = (retargs[1] ~= nil)
end
return ok, retargs
end
local function equals(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "equals", 2, tostring(argcnt) }), level)
local result = arguments[1] == arguments[2]
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
set_failure_message(state, arguments[3])
return result
end
local function same(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "same", 2, tostring(argcnt) }), level)
if type(arguments[1]) == 'table' and type(arguments[2]) == 'table' then
local result, crumbs = util.deepcompare(arguments[1], arguments[2], true)
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
arguments.fmtargs = arguments.fmtargs or {}
arguments.fmtargs[1] = { crumbs = crumbs }
arguments.fmtargs[2] = { crumbs = crumbs }
set_failure_message(state, arguments[3])
return result
end
local result = arguments[1] == arguments[2]
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
set_failure_message(state, arguments[3])
return result
end
local function truthy(state, arguments, level)
set_failure_message(state, arguments[2])
return arguments[1] ~= false and arguments[1] ~= nil
end
local function falsy(state, arguments, level)
return not truthy(state, arguments, level)
end
local function has_error(state, arguments, level)
local level = (level or 1) + 1
local retargs = util.shallowcopy(arguments)
local func = arguments[1]
local err_expected = arguments[2]
local failure_message = arguments[3]
assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error", "function or callable object", type(func) }), level)
local ok, err_actual = pcall(func)
if type(err_actual) == 'string' then
-- remove 'path/to/file:line: ' from string
err_actual = err_actual:gsub('^.-:%d+: ', '', 1)
end
retargs[1] = err_actual
arguments.nofmt = {}
arguments.n = 2
arguments[1] = (ok and '(no error)' or err_actual)
arguments[2] = (err_expected == nil and '(error)' or err_expected)
arguments.nofmt[1] = ok
arguments.nofmt[2] = (err_expected == nil)
set_failure_message(state, failure_message)
if ok or err_expected == nil then
return not ok, retargs
end
if type(err_expected) == 'string' then
-- err_actual must be (convertible to) a string
if util.hastostring(err_actual) then
err_actual = tostring(err_actual)
retargs[1] = err_actual
end
if type(err_actual) == 'string' then
return err_expected == err_actual, retargs
end
elseif type(err_expected) == 'number' then
if type(err_actual) == 'string' then
return tostring(err_expected) == tostring(tonumber(err_actual)), retargs
end
end
return same(state, {err_expected, err_actual, ["n"] = 2}), retargs
end
local function error_matches(state, arguments, level)
local level = (level or 1) + 1
local retargs = util.shallowcopy(arguments)
local argcnt = arguments.n
local func = arguments[1]
local pattern = arguments[2]
assert(argcnt > 1, s("assertion.internal.argtolittle", { "error_matches", 2, tostring(argcnt) }), level)
assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error_matches", "function or callable object", type(func) }), level)
assert(pattern == nil or type(pattern) == "string", s("assertion.internal.badargtype", { 2, "error", "string", type(pattern) }), level)
local failure_message
local init_arg_num = 3
for i=3,argcnt,1 do
if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then
if i == 3 then init_arg_num = init_arg_num + 1 end
failure_message = util.tremove(arguments, i)
break
end
end
local init = arguments[3]
local plain = arguments[4]
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level)
local ok, err_actual = pcall(func)
if type(err_actual) == 'string' then
-- remove 'path/to/file:line: ' from string
err_actual = err_actual:gsub('^.-:%d+: ', '', 1)
end
retargs[1] = err_actual
arguments.nofmt = {}
arguments.n = 2
arguments[1] = (ok and '(no error)' or err_actual)
arguments[2] = pattern
arguments.nofmt[1] = ok
arguments.nofmt[2] = false
set_failure_message(state, failure_message)
if ok then return not ok, retargs end
if err_actual == nil and pattern == nil then
return true, {}
end
-- err_actual must be (convertible to) a string
if util.hastostring(err_actual) then
err_actual = tostring(err_actual)
retargs[1] = err_actual
end
if type(err_actual) == 'string' then
local ok
local retargs_ok
if plain then
retargs_ok = { pattern }
ok = (err_actual:find(pattern, init, plain) ~= nil)
else
retargs_ok = { err_actual:match(pattern, init) }
ok = (retargs_ok[1] ~= nil)
end
if ok then retargs = retargs_ok end
return ok, retargs
end
return false, retargs
end
local function is_true(state, arguments, level)
util.tinsert(arguments, 2, true)
set_failure_message(state, arguments[3])
return arguments[1] == arguments[2]
end
local function is_false(state, arguments, level)
util.tinsert(arguments, 2, false)
set_failure_message(state, arguments[3])
return arguments[1] == arguments[2]
end
local function is_type(state, arguments, level, etype)
util.tinsert(arguments, 2, "type " .. etype)
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[2] = true
set_failure_message(state, arguments[3])
return arguments.n > 1 and type(arguments[1]) == etype
end
local function returned_arguments(state, arguments, level)
arguments[1] = tostring(arguments[1])
arguments[2] = tostring(arguments.n - 1)
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[1] = true
arguments.nofmt[2] = true
if arguments.n < 2 then arguments.n = 2 end
return arguments[1] == arguments[2]
end
local function set_message(state, arguments, level)
state.failure_message = arguments[1]
end
local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end
local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end
local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end
local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end
local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end
local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end
local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end
local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end
assert:register("modifier", "message", set_message)
assert:register("assertion", "true", is_true, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "false", is_false, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "boolean", is_boolean, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "number", is_number, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "string", is_string, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "table", is_table, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "nil", is_nil, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "userdata", is_userdata, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "function", is_function, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "thread", is_thread, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "returned_arguments", returned_arguments, "assertion.returned_arguments.positive", "assertion.returned_arguments.negative")
assert:register("assertion", "same", same, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "matches", matches, "assertion.matches.positive", "assertion.matches.negative")
assert:register("assertion", "match", matches, "assertion.matches.positive", "assertion.matches.negative")
assert:register("assertion", "near", near, "assertion.near.positive", "assertion.near.negative")
assert:register("assertion", "equals", equals, "assertion.equals.positive", "assertion.equals.negative")
assert:register("assertion", "equal", equals, "assertion.equals.positive", "assertion.equals.negative")
assert:register("assertion", "unique", unique, "assertion.unique.positive", "assertion.unique.negative")
assert:register("assertion", "error", has_error, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "errors", has_error, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "error_matches", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "error_match", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "matches_error", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "match_error", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "truthy", truthy, "assertion.truthy.positive", "assertion.truthy.negative")
assert:register("assertion", "falsy", falsy, "assertion.falsy.positive", "assertion.falsy.negative")

@ -0,0 +1,9 @@
-- no longer needed, only for backward compatibility
local unpack = require ("luassert.util").unpack
return {
unpack = function(...)
print(debug.traceback("WARN: calling deprecated function 'luassert.compatibility.unpack' use 'luassert.util.unpack' instead"))
return unpack(...)
end
}

@ -0,0 +1,33 @@
local format = function(str)
if type(str) ~= "string" then
return nil
end
local result = "Binary string length; " .. tostring(#str) .. " bytes\n"
local i = 1
local hex = ""
local chr = ""
while i <= #str do
local byte = str:byte(i)
hex = string.format("%s%2x ", hex, byte)
if byte < 32 then
byte = string.byte(".")
end
chr = chr .. string.char(byte)
if math.floor(i / 16) == i / 16 or i == #str then
-- reached end of line
hex = hex .. string.rep(" ", 16 * 3 - #hex)
chr = chr .. string.rep(" ", 16 - #chr)
result = result .. hex:sub(1, 8 * 3) .. " " .. hex:sub(8 * 3 + 1, -1) .. " " .. chr:sub(1, 8) .. " " ..
chr:sub(9, -1) .. "\n"
hex = ""
chr = ""
end
i = i + 1
end
return result
end
return format

@ -0,0 +1,258 @@
-- module will not return anything, only register formatters with the main assert engine
local assert = require('luassert.assert')
local match = require('luassert.match')
local util = require('luassert.util')
local colors = setmetatable({
none = function(c)
return c
end,
}, {
__index = function(self, key)
local ok, term = pcall(require, 'term')
local isatty = io.type(io.stdout) == 'file' and ok and term.isatty(io.stdout)
if not ok or not isatty or not term.colors then
return function(c)
return c
end
end
return function(c)
for token in key:gmatch("[^%.]+") do
c = term.colors[token](c)
end
return c
end
end,
})
local function fmt_string(arg)
if type(arg) == "string" then
return string.format("(string) '%s'", arg)
end
end
-- A version of tostring which formats numbers more precisely.
local function tostr(arg)
if type(arg) ~= "number" then
return tostring(arg)
end
if arg ~= arg then
return "NaN"
elseif arg == 1 / 0 then
return "Inf"
elseif arg == -1 / 0 then
return "-Inf"
end
local str = string.format("%.20g", arg)
if math.type and math.type(arg) == "float" and not str:find("[%.,]") then
-- Number is a float but looks like an integer.
-- Insert ".0" after first run of digits.
str = str:gsub("%d+", "%0.0", 1)
end
return str
end
local function fmt_number(arg)
if type(arg) == "number" then
return string.format("(number) %s", tostr(arg))
end
end
local function fmt_boolean(arg)
if type(arg) == "boolean" then
return string.format("(boolean) %s", tostring(arg))
end
end
local function fmt_nil(arg)
if type(arg) == "nil" then
return "(nil)"
end
end
local type_priorities = {
number = 1,
boolean = 2,
string = 3,
table = 4,
["function"] = 5,
userdata = 6,
thread = 7,
}
local function is_in_array_part(key, length)
return type(key) == "number" and 1 <= key and key <= length and math.floor(key) == key
end
local function get_sorted_keys(t)
local keys = {}
local nkeys = 0
for key in pairs(t) do
nkeys = nkeys + 1
keys[nkeys] = key
end
local length = #t
local function key_comparator(key1, key2)
local type1, type2 = type(key1), type(key2)
local priority1 = is_in_array_part(key1, length) and 0 or type_priorities[type1] or 8
local priority2 = is_in_array_part(key2, length) and 0 or type_priorities[type2] or 8
if priority1 == priority2 then
if type1 == "string" or type1 == "number" then
return key1 < key2
elseif type1 == "boolean" then
return key1 -- put true before false
end
else
return priority1 < priority2
end
end
table.sort(keys, key_comparator)
return keys, nkeys
end
local function fmt_table(arg, fmtargs)
if type(arg) ~= "table" then
return
end
local tmax = assert:get_parameter("TableFormatLevel")
local showrec = assert:get_parameter("TableFormatShowRecursion")
local errchar = assert:get_parameter("TableErrorHighlightCharacter") or ""
local errcolor = assert:get_parameter("TableErrorHighlightColor") or "none"
local crumbs = fmtargs and fmtargs.crumbs or {}
local cache = {}
local type_desc
if getmetatable(arg) == nil then
type_desc = "(" .. tostring(arg) .. ") "
elseif not pcall(setmetatable, arg, getmetatable(arg)) then
-- cannot set same metatable, so it is protected, skip id
type_desc = "(table) "
else
-- unprotected metatable, temporary remove the mt
local mt = getmetatable(arg)
setmetatable(arg, nil)
type_desc = "(" .. tostring(arg) .. ") "
setmetatable(arg, mt)
end
local function ft(t, l, with_crumbs)
if showrec and cache[t] and cache[t] > 0 then
return "{ ... recursive }"
end
if next(t) == nil then
return "{ }"
end
if l > tmax and tmax >= 0 then
return "{ ... more }"
end
local result = "{"
local keys, nkeys = get_sorted_keys(t)
cache[t] = (cache[t] or 0) + 1
local crumb = crumbs[#crumbs - l + 1]
for i = 1, nkeys do
local k = keys[i]
local v = t[k]
local use_crumbs = with_crumbs and k == crumb
if type(v) == "table" then
v = ft(v, l + 1, use_crumbs)
elseif type(v) == "string" then
v = "'" .. v .. "'"
end
local ch = use_crumbs and errchar or ""
local indent = string.rep(" ", l * 2 - ch:len())
local mark = (ch:len() == 0 and "" or colors[errcolor](ch))
result = result .. string.format("\n%s%s[%s] = %s", indent, mark, tostr(k), tostr(v))
end
cache[t] = cache[t] - 1
return result .. " }"
end
return type_desc .. ft(arg, 1, true)
end
local function fmt_function(arg)
if type(arg) == "function" then
local debug_info = debug.getinfo(arg)
return string.format("%s @ line %s in %s", tostring(arg), tostring(debug_info.linedefined),
tostring(debug_info.source))
end
end
local function fmt_userdata(arg)
if type(arg) == "userdata" then
return string.format("(userdata) '%s'", tostring(arg))
end
end
local function fmt_thread(arg)
if type(arg) == "thread" then
return string.format("(thread) '%s'", tostring(arg))
end
end
local function fmt_matcher(arg)
if not match.is_matcher(arg) then
return
end
local not_inverted = {
[true] = "is.",
[false] = "no.",
}
local args = {}
for idx = 1, arg.arguments.n do
table.insert(args, assert:format({
arg.arguments[idx],
n = 1,
})[1])
end
return string.format("(matcher) %s%s(%s)", not_inverted[arg.mod], tostring(arg.name), table.concat(args, ", "))
end
local function fmt_arglist(arglist)
if not util.is_arglist(arglist) then
return
end
local formatted_vals = {}
for idx = 1, arglist.n do
table.insert(formatted_vals, assert:format({
arglist[idx],
n = 1,
})[1])
end
return "(values list) (" .. table.concat(formatted_vals, ", ") .. ")"
end
assert:add_formatter(fmt_string)
assert:add_formatter(fmt_number)
assert:add_formatter(fmt_boolean)
assert:add_formatter(fmt_nil)
assert:add_formatter(fmt_table)
assert:add_formatter(fmt_function)
assert:add_formatter(fmt_userdata)
assert:add_formatter(fmt_thread)
assert:add_formatter(fmt_matcher)
assert:add_formatter(fmt_arglist)
-- Set default table display depth for table formatter
assert:set_parameter("TableFormatLevel", 3)
assert:set_parameter("TableFormatShowRecursion", false)
assert:set_parameter("TableErrorHighlightCharacter", "*")
assert:set_parameter("TableErrorHighlightColor", "none")

@ -0,0 +1,18 @@
local assert = require('luassert.assert')
assert._COPYRIGHT = "Copyright (c) 2018 Olivine Labs, LLC."
assert._DESCRIPTION =
"Extends Lua's built-in assertions to provide additional tests and the ability to create your own."
assert._VERSION = "Luassert 1.8.0"
-- load basic asserts
require('luassert.assertions')
require('luassert.modifiers')
require('luassert.array')
require('luassert.matchers')
require('luassert.formatters')
-- load default language
require('luassert.languages.en')
return assert

@ -0,0 +1,52 @@
local s = require('luassert.say')
s:set_namespace('en')
s:set("assertion.same.positive", "Expected objects to be the same.\nPassed in:\n%s\nExpected:\n%s")
s:set("assertion.same.negative", "Expected objects to not be the same.\nPassed in:\n%s\nDid not expect:\n%s")
s:set("assertion.equals.positive", "Expected objects to be equal.\nPassed in:\n%s\nExpected:\n%s")
s:set("assertion.equals.negative", "Expected objects to not be equal.\nPassed in:\n%s\nDid not expect:\n%s")
s:set("assertion.near.positive", "Expected values to be near.\nPassed in:\n%s\nExpected:\n%s +/- %s")
s:set("assertion.near.negative", "Expected values to not be near.\nPassed in:\n%s\nDid not expect:\n%s +/- %s")
s:set("assertion.matches.positive", "Expected strings to match.\nPassed in:\n%s\nExpected:\n%s")
s:set("assertion.matches.negative", "Expected strings not to match.\nPassed in:\n%s\nDid not expect:\n%s")
s:set("assertion.unique.positive", "Expected object to be unique:\n%s")
s:set("assertion.unique.negative", "Expected object to not be unique:\n%s")
s:set("assertion.error.positive", "Expected a different error.\nCaught:\n%s\nExpected:\n%s")
s:set("assertion.error.negative", "Expected no error, but caught:\n%s")
s:set("assertion.truthy.positive", "Expected to be truthy, but value was:\n%s")
s:set("assertion.truthy.negative", "Expected to not be truthy, but value was:\n%s")
s:set("assertion.falsy.positive", "Expected to be falsy, but value was:\n%s")
s:set("assertion.falsy.negative", "Expected to not be falsy, but value was:\n%s")
s:set("assertion.called.positive", "Expected to be called %s time(s), but was called %s time(s)")
s:set("assertion.called.negative", "Expected not to be called exactly %s time(s), but it was.")
s:set("assertion.called_at_least.positive", "Expected to be called at least %s time(s), but was called %s time(s)")
s:set("assertion.called_at_most.positive", "Expected to be called at most %s time(s), but was called %s time(s)")
s:set("assertion.called_more_than.positive", "Expected to be called more than %s time(s), but was called %s time(s)")
s:set("assertion.called_less_than.positive", "Expected to be called less than %s time(s), but was called %s time(s)")
s:set("assertion.called_with.positive",
"Function was never called with matching arguments.\nCalled with (last call if any):\n%s\nExpected:\n%s")
s:set("assertion.called_with.negative",
"Function was called with matching arguments at least once.\nCalled with (last matching call):\n%s\nDid not expect:\n%s")
s:set("assertion.returned_with.positive",
"Function never returned matching arguments.\nReturned (last call if any):\n%s\nExpected:\n%s")
s:set("assertion.returned_with.negative",
"Function returned matching arguments at least once.\nReturned (last matching call):\n%s\nDid not expect:\n%s")
s:set("assertion.returned_arguments.positive", "Expected to be called with %s argument(s), but was called with %s")
s:set("assertion.returned_arguments.negative", "Expected not to be called with %s argument(s), but was called with %s")
-- errors
s:set("assertion.internal.argtolittle", "the '%s' function requires a minimum of %s arguments, got: %s")
s:set("assertion.internal.badargtype", "bad argument #%s to '%s' (%s expected, got %s)")

@ -0,0 +1,31 @@
local s = require('luassert.say')
s:set_namespace('zh')
s:set("assertion.same.positive", "希望对象应该相同.\n实际值:\n%s\n希望值:\n%s")
s:set("assertion.same.negative", "希望对象应该不相同.\n实际值:\n%s\n不希望与:\n%s\n相同")
s:set("assertion.equals.positive", "希望对象应该相等.\n实际值:\n%s\n希望值:\n%s")
s:set("assertion.equals.negative", "希望对象应该不相等.\n实际值:\n%s\n不希望等于:\n%s")
s:set("assertion.unique.positive", "希望对象是唯一的:\n%s")
s:set("assertion.unique.negative", "希望对象不是唯一的:\n%s")
s:set("assertion.error.positive", "希望有错误被抛出.")
s:set("assertion.error.negative", "希望没有错误被抛出.\n%s")
s:set("assertion.truthy.positive", "希望结果为真,但是实际为:\n%s")
s:set("assertion.truthy.negative", "希望结果不为真,但是实际为:\n%s")
s:set("assertion.falsy.positive", "希望结果为假,但是实际为:\n%s")
s:set("assertion.falsy.negative", "希望结果不为假,但是实际为:\n%s")
s:set("assertion.called.positive", "希望被调用%s次, 但实际被调用了%s次")
s:set("assertion.called.negative", "不希望正好被调用%s次, 但是正好被调用了那么多次.")
s:set("assertion.called_with.positive", "希望没有参数的调用函数")
s:set("assertion.called_with.negative", "希望有参数的调用函数")
-- errors
s:set("assertion.internal.argtolittle", "函数'%s'需要最少%s个参数, 实际有%s个参数\n")
s:set("assertion.internal.badargtype", "bad argument #%s: 函数'%s'需要一个%s作为参数, 实际为: %s\n")

@ -0,0 +1,79 @@
local namespace = require 'luassert.namespaces'
local util = require 'luassert.util'
local matcher_mt = {
__call = function(self, value)
return self.callback(value) == self.mod
end,
}
local state_mt = {
__call = function(self, ...)
local keys = util.extract_keys("matcher", self.tokens)
self.tokens = {}
local matcher
for _, key in ipairs(keys) do
matcher = namespace.matcher[key] or matcher
end
if matcher then
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self)
end
end
local arguments = util.make_arglist(...)
local matches = matcher.callback(self, arguments, util.errorlevel())
return setmetatable({
name = matcher.name,
mod = self.mod,
callback = matches,
arguments = arguments,
}, matcher_mt)
else
local arguments = util.make_arglist(...)
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self, arguments, util.errorlevel())
end
end
end
return self
end,
__index = function(self, key)
for token in key:lower():gmatch('[^_]+') do
table.insert(self.tokens, token)
end
return self
end
}
local match = {
_ = setmetatable({mod=true, callback=function() return true end}, matcher_mt),
state = function() return setmetatable({mod=true, tokens={}}, state_mt) end,
is_matcher = function(object)
return type(object) == "table" and getmetatable(object) == matcher_mt
end,
is_ref_matcher = function(object)
local ismatcher = (type(object) == "table" and getmetatable(object) == matcher_mt)
return ismatcher and object.name == "ref"
end,
}
local mt = {
__index = function(self, key)
return rawget(self, key) or self.state()[key]
end,
}
return setmetatable(match, mt)

@ -0,0 +1,64 @@
local assert = require('luassert.assert')
local match = require('luassert.match')
local s = require('luassert.say')
local function none(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", {"none", 1, tostring(argcnt)}), level)
for i = 1, argcnt do
assert(match.is_matcher(arguments[i]),
s("assertion.internal.badargtype", {1, "none", "matcher", type(arguments[i])}), level)
end
return function(value)
for _, matcher in ipairs(arguments) do
if matcher(value) then
return false
end
end
return true
end
end
local function any(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", {"any", 1, tostring(argcnt)}), level)
for i = 1, argcnt do
assert(match.is_matcher(arguments[i]),
s("assertion.internal.badargtype", {1, "any", "matcher", type(arguments[i])}), level)
end
return function(value)
for _, matcher in ipairs(arguments) do
if matcher(value) then
return true
end
end
return false
end
end
local function all(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", {"all", 1, tostring(argcnt)}), level)
for i = 1, argcnt do
assert(match.is_matcher(arguments[i]),
s("assertion.internal.badargtype", {1, "all", "matcher", type(arguments[i])}), level)
end
return function(value)
for _, matcher in ipairs(arguments) do
if not matcher(value) then
return false
end
end
return true
end
end
assert:register("matcher", "none_of", none)
assert:register("matcher", "any_of", any)
assert:register("matcher", "all_of", all)

@ -0,0 +1,173 @@
-- module will return the list of matchers, and registers matchers with the main assert engine
-- matchers take 1 parameters;
-- 1) state
-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils
-- 3) level The level of the error position relative to the called function
-- returns; function (or callable object); a function that, given an argument, returns a boolean
local assert = require('luassert.assert')
local astate = require('luassert.state')
local util = require('luassert.util')
local s = require('luassert.say')
local function format(val)
return astate.format_argument(val) or tostring(val)
end
local function unique(state, arguments, level)
local deep = arguments[1]
return function(value)
local list = value
for k,v in pairs(list) do
for k2, v2 in pairs(list) do
if k ~= k2 then
if deep and util.deepcompare(v, v2, true) then
return false
else
if v == v2 then
return false
end
end
end
end
end
return true
end
end
local function near(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "near", 2, tostring(argcnt) }), level)
local expected = tonumber(arguments[1])
local tolerance = tonumber(arguments[2])
local numbertype = "number or object convertible to a number"
assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level)
assert(tolerance, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level)
return function(value)
local actual = tonumber(value)
if not actual then return false end
return (actual >= expected - tolerance and actual <= expected + tolerance)
end
end
local function matches(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "matches", 1, tostring(argcnt) }), level)
local pattern = arguments[1]
local init = arguments[2]
local plain = arguments[3]
assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { 2, "matches", "number", type(arguments[2]) }), level)
return function(value)
local actualtype = type(value)
local actual = nil
if actualtype == "string" or actualtype == "number" or
actualtype == "table" and (getmetatable(value) or {}).__tostring then
actual = tostring(value)
end
if not actual then return false end
return (actual:find(pattern, init, plain) ~= nil)
end
end
local function equals(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "equals", 1, tostring(argcnt) }), level)
return function(value)
return value == arguments[1]
end
end
local function same(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "same", 1, tostring(argcnt) }), level)
return function(value)
if type(value) == 'table' and type(arguments[1]) == 'table' then
local result = util.deepcompare(value, arguments[1], true)
return result
end
return value == arguments[1]
end
end
local function ref(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
local argtype = type(arguments[1])
local isobject = (argtype == "table" or argtype == "function" or argtype == "thread" or argtype == "userdata")
assert(argcnt > 0, s("assertion.internal.argtolittle", { "ref", 1, tostring(argcnt) }), level)
assert(isobject, s("assertion.internal.badargtype", { 1, "ref", "object", argtype }), level)
return function(value)
return value == arguments[1]
end
end
local function is_true(state, arguments, level)
return function(value)
return value == true
end
end
local function is_false(state, arguments, level)
return function(value)
return value == false
end
end
local function truthy(state, arguments, level)
return function(value)
return value ~= false and value ~= nil
end
end
local function falsy(state, arguments, level)
local is_truthy = truthy(state, arguments, level)
return function(value)
return not is_truthy(value)
end
end
local function is_type(state, arguments, level, etype)
return function(value)
return type(value) == etype
end
end
local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end
local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end
local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end
local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end
local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end
local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end
local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end
local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end
assert:register("matcher", "true", is_true)
assert:register("matcher", "false", is_false)
assert:register("matcher", "nil", is_nil)
assert:register("matcher", "boolean", is_boolean)
assert:register("matcher", "number", is_number)
assert:register("matcher", "string", is_string)
assert:register("matcher", "table", is_table)
assert:register("matcher", "function", is_function)
assert:register("matcher", "userdata", is_userdata)
assert:register("matcher", "thread", is_thread)
assert:register("matcher", "ref", ref)
assert:register("matcher", "same", same)
assert:register("matcher", "matches", matches)
assert:register("matcher", "match", matches)
assert:register("matcher", "near", near)
assert:register("matcher", "equals", equals)
assert:register("matcher", "equal", equals)
assert:register("matcher", "unique", unique)
assert:register("matcher", "truthy", truthy)
assert:register("matcher", "falsy", falsy)

@ -0,0 +1,3 @@
-- load basic machers
require('luassert.matchers.core')
require('luassert.matchers.composite')

@ -0,0 +1,65 @@
-- module will return a mock module table, and will not register any assertions
local spy = require 'luassert.spy'
local stub = require 'luassert.stub'
local function mock_apply(object, action)
if type(object) ~= "table" then
return
end
if spy.is_spy(object) then
return object[action](object)
end
for k, v in pairs(object) do
mock_apply(v, action)
end
return object
end
local mock
mock = {
new = function(object, dostub, func, self, key)
local visited = {}
local function do_mock(object, self, key)
local mock_handlers = {
["table"] = function()
if spy.is_spy(object) or visited[object] then
return
end
visited[object] = true
for k, v in pairs(object) do
object[k] = do_mock(v, object, k)
end
return object
end,
["function"] = function()
if dostub then
return stub(self, key, func)
elseif self == nil then
return spy.new(object)
else
return spy.on(self, key)
end
end,
}
local handler = mock_handlers[type(object)]
return handler and handler() or object
end
return do_mock(object, self, key)
end,
clear = function(object)
return mock_apply(object, "clear")
end,
revert = function(object)
return mock_apply(object, "revert")
end,
}
return setmetatable(mock, {
__call = function(self, ...)
-- mock originally was a function only. Now that it is a module table
-- the __call method is required for backward compatibility
return mock.new(...)
end,
})

@ -0,0 +1,19 @@
-- module will not return anything, only register assertions/modifiers with the main assert engine
local assert = require('luassert.assert')
local function is(state)
return state
end
local function is_not(state)
state.mod = not state.mod
return state
end
assert:register("modifier", "is", is)
assert:register("modifier", "are", is)
assert:register("modifier", "was", is)
assert:register("modifier", "has", is)
assert:register("modifier", "does", is)
assert:register("modifier", "not", is_not)
assert:register("modifier", "no", is_not)

@ -0,0 +1,2 @@
-- stores the list of namespaces
return {}

@ -0,0 +1,64 @@
local unpack = table.unpack or unpack
local registry = {}
local current_namespace
local fallback_namespace
local s = {
_COPYRIGHT = "Copyright (c) 2012 Olivine Labs, LLC.",
_DESCRIPTION = "A simple string key/value store for i18n or any other case where you want namespaced strings.",
_VERSION = "Say 1.3",
set_namespace = function(self, namespace)
current_namespace = namespace
if not registry[current_namespace] then
registry[current_namespace] = {}
end
end,
set_fallback = function(self, namespace)
fallback_namespace = namespace
if not registry[fallback_namespace] then
registry[fallback_namespace] = {}
end
end,
set = function(self, key, value)
registry[current_namespace][key] = value
end,
}
local __meta = {
__call = function(self, key, vars)
if vars ~= nil and type(vars) ~= "table" then
error(("expected parameter table to be a table, got '%s'"):format(type(vars)), 2)
end
vars = vars or {}
local str = registry[current_namespace][key] or registry[fallback_namespace][key]
if str == nil then
return nil
end
str = tostring(str)
local strings = {}
for i = 1, vars.n or #vars do
table.insert(strings, tostring(vars[i]))
end
return #strings > 0 and str:format(unpack(strings)) or str
end,
__index = function(self, key)
return registry[key]
end,
}
s:set_fallback('en')
s:set_namespace('en')
s._registry = registry
return setmetatable(s, __meta)

@ -0,0 +1,215 @@
-- module will return spy table, and register its assertions with the main assert engine
local assert = require('luassert.assert')
local util = require('luassert.util')
-- Spy metatable
local spy_mt = {
__call = function(self, ...)
local arguments = util.make_arglist(...)
table.insert(self.calls, util.copyargs(arguments))
local function get_returns(...)
local returnvals = util.make_arglist(...)
table.insert(self.returnvals, util.copyargs(returnvals))
return ...
end
return get_returns(self.callback(...))
end,
}
local spy -- must make local before defining table, because table contents refers to the table (recursion)
spy = {
new = function(callback)
callback = callback or function()
end
if not util.callable(callback) then
error("Cannot spy on type '" .. type(callback) .. "', only on functions or callable elements",
util.errorlevel())
end
local s = setmetatable({
calls = {},
returnvals = {},
callback = callback,
target_table = nil, -- these will be set when using 'spy.on'
target_key = nil,
revert = function(self)
if not self.reverted then
if self.target_table and self.target_key then
self.target_table[self.target_key] = self.callback
end
self.reverted = true
end
return self.callback
end,
clear = function(self)
self.calls = {}
self.returnvals = {}
return self
end,
called = function(self, times, compare)
if times or compare then
local compare = compare or function(count, expected)
return count == expected
end
return compare(#self.calls, times), #self.calls
end
return (#self.calls > 0), #self.calls
end,
called_with = function(self, args)
local last_arglist = nil
if #self.calls > 0 then
last_arglist = self.calls[#self.calls].vals
end
local matching_arglists = util.matchargs(self.calls, args)
if matching_arglists ~= nil then
return true, matching_arglists.vals
end
return false, last_arglist
end,
returned_with = function(self, args)
local last_returnvallist = nil
if #self.returnvals > 0 then
last_returnvallist = self.returnvals[#self.returnvals].vals
end
local matching_returnvallists = util.matchargs(self.returnvals, args)
if matching_returnvallists ~= nil then
return true, matching_returnvallists.vals
end
return false, last_returnvallist
end,
}, spy_mt)
assert:add_spy(s) -- register with the current state
return s
end,
is_spy = function(object)
return type(object) == "table" and getmetatable(object) == spy_mt
end,
on = function(target_table, target_key)
local s = spy.new(target_table[target_key])
target_table[target_key] = s
-- store original data
s.target_table = target_table
s.target_key = target_key
return s
end,
}
local function set_spy(state, arguments, level)
state.payload = arguments[1]
if arguments[2] ~= nil then
state.failure_message = arguments[2]
end
end
local function returned_with(state, arguments, level)
local level = (level or 1) + 1
local payload = rawget(state, "payload")
if payload and payload.returned_with then
local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments)
local expected_returnvallist = util.shallowcopy(arguments)
util.cleararglist(arguments)
util.tinsert(arguments, 1, matching_or_last_returnvallist)
util.tinsert(arguments, 2, expected_returnvallist)
return assertion_holds
else
error("'returned_with' must be chained after 'spy(aspy)'", level)
end
end
local function called_with(state, arguments, level)
local level = (level or 1) + 1
local payload = rawget(state, "payload")
if payload and payload.called_with then
local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments)
local expected_arglist = util.shallowcopy(arguments)
util.cleararglist(arguments)
util.tinsert(arguments, 1, matching_or_last_arglist)
util.tinsert(arguments, 2, expected_arglist)
return assertion_holds
else
error("'called_with' must be chained after 'spy(aspy)'", level)
end
end
local function called(state, arguments, level, compare)
local level = (level or 1) + 1
local num_times = arguments[1]
if not num_times and not state.mod then
state.mod = true
num_times = 0
end
local payload = rawget(state, "payload")
if payload and type(payload) == "table" and payload.called then
local result, count = state.payload:called(num_times, compare)
arguments[1] = tostring(num_times or ">0")
util.tinsert(arguments, 2, tostring(count))
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[1] = true
arguments.nofmt[2] = true
return result
elseif payload and type(payload) == "function" then
error(
"When calling 'spy(aspy)', 'aspy' must not be the original function, but the spy function replacing the original",
level)
else
error("'called' must be chained after 'spy(aspy)'", level)
end
end
local function called_at_least(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected)
return count >= expected
end)
end
local function called_at_most(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected)
return count <= expected
end)
end
local function called_more_than(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected)
return count > expected
end)
end
local function called_less_than(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected)
return count < expected
end)
end
assert:register("modifier", "spy", set_spy)
assert:register("assertion", "returned_with", returned_with, "assertion.returned_with.positive",
"assertion.returned_with.negative")
assert:register("assertion", "called_with", called_with, "assertion.called_with.positive",
"assertion.called_with.negative")
assert:register("assertion", "called", called, "assertion.called.positive", "assertion.called.negative")
assert:register("assertion", "called_at_least", called_at_least, "assertion.called_at_least.positive",
"assertion.called_less_than.positive")
assert:register("assertion", "called_at_most", called_at_most, "assertion.called_at_most.positive",
"assertion.called_more_than.positive")
assert:register("assertion", "called_more_than", called_more_than, "assertion.called_more_than.positive",
"assertion.called_at_most.positive")
assert:register("assertion", "called_less_than", called_less_than, "assertion.called_less_than.positive",
"assertion.called_at_least.positive")
return setmetatable(spy, {
__call = function(self, ...)
return spy.new(...)
end,
})

@ -0,0 +1,134 @@
-- maintains a state of the assert engine in a linked-list fashion
-- records; formatters, parameters, spies and stubs
local state_mt = {
__call = function(self)
self:revert()
end,
}
local spies_mt = {
__mode = "kv",
}
local nilvalue = {} -- unique ID to refer to nil values for parameters
-- will hold the current state
local current
-- exported module table
local state = {}
------------------------------------------------------
-- Reverts to a (specific) snapshot.
-- @param self (optional) the snapshot to revert to. If not provided, it will revert to the last snapshot.
state.revert = function(self)
if not self then
-- no snapshot given, so move 1 up
self = current
if not self.previous then
-- top of list, no previous one, nothing to do
return
end
end
if getmetatable(self) ~= state_mt then
error("Value provided is not a valid snapshot", 2)
end
if self.next then
self.next:revert()
end
-- revert formatters in 'last'
self.formatters = {}
-- revert parameters in 'last'
self.parameters = {}
-- revert spies/stubs in 'last'
for s, _ in pairs(self.spies) do
self.spies[s] = nil
s:revert()
end
setmetatable(self, nil) -- invalidate as a snapshot
current = self.previous
current.next = nil
end
------------------------------------------------------
-- Creates a new snapshot.
-- @return snapshot table
state.snapshot = function()
local new = setmetatable({
formatters = {},
parameters = {},
spies = setmetatable({}, spies_mt),
previous = current,
revert = state.revert,
}, state_mt)
if current then
current.next = new
end
current = new
return current
end
-- FORMATTERS
state.add_formatter = function(callback)
table.insert(current.formatters, 1, callback)
end
state.remove_formatter = function(callback, s)
s = s or current
for i, v in ipairs(s.formatters) do
if v == callback then
table.remove(s.formatters, i)
break
end
end
-- wasn't found, so traverse up 1 state
if s.previous then
state.remove_formatter(callback, s.previous)
end
end
state.format_argument = function(val, s, fmtargs)
s = s or current
for _, fmt in ipairs(s.formatters) do
local valfmt = fmt(val, fmtargs)
if valfmt ~= nil then
return valfmt
end
end
-- nothing found, check snapshot 1 up in list
if s.previous then
return state.format_argument(val, s.previous, fmtargs)
end
return nil -- end of list, couldn't format
end
-- PARAMETERS
state.set_parameter = function(name, value)
if value == nil then
value = nilvalue
end
current.parameters[name] = value
end
state.get_parameter = function(name, s)
s = s or current
local val = s.parameters[name]
if val == nil and s.previous then
-- not found, so check 1 up in list
return state.get_parameter(name, s.previous)
end
if val ~= nilvalue then
return val
end
return nil
end
-- SPIES / STUBS
state.add_spy = function(spy)
current.spies[spy] = true
end
state.snapshot() -- create initial state
return state

@ -0,0 +1,109 @@
-- module will return a stub module table
local assert = require 'luassert.assert'
local spy = require 'luassert.spy'
local util = require 'luassert.util'
local unpack = util.unpack
local pack = util.pack
local stub = {}
function stub.new(object, key, ...)
if object == nil and key == nil then
-- called without arguments, create a 'blank' stub
object = {}
key = ""
end
local return_values = pack(...)
assert(type(object) == "table" and key ~= nil,
"stub.new(): Can only create stub on a table key, call with 2 params; table, key", util.errorlevel())
assert(object[key] == nil or util.callable(object[key]),
"stub.new(): The element for which to create a stub must either be callable, or be nil", util.errorlevel())
local old_elem = object[key] -- keep existing element (might be nil!)
local fn = (return_values.n == 1 and util.callable(return_values[1]) and return_values[1])
local defaultfunc = fn or function()
return unpack(return_values)
end
local oncalls = {}
local callbacks = {}
local stubfunc = function(...)
local args = util.make_arglist(...)
local match = util.matchoncalls(oncalls, args)
if match then
return callbacks[match](...)
end
return defaultfunc(...)
end
object[key] = stubfunc -- set the stubfunction
local s = spy.on(object, key) -- create a spy on top of the stub function
local spy_revert = s.revert -- keep created revert function
s.revert = function(self) -- wrap revert function to restore original element
if not self.reverted then
spy_revert(self)
object[key] = old_elem
self.reverted = true
end
return old_elem
end
s.returns = function(...)
local return_args = pack(...)
defaultfunc = function()
return unpack(return_args)
end
return s
end
s.invokes = function(func)
defaultfunc = function(...)
return func(...)
end
return s
end
s.by_default = {
returns = s.returns,
invokes = s.invokes,
}
s.on_call_with = function(...)
local match_args = util.make_arglist(...)
match_args = util.copyargs(match_args)
return {
returns = function(...)
local return_args = pack(...)
table.insert(oncalls, match_args)
callbacks[match_args] = function()
return unpack(return_args)
end
return s
end,
invokes = function(func)
table.insert(oncalls, match_args)
callbacks[match_args] = function(...)
return func(...)
end
return s
end,
}
end
return s
end
local function set_stub(state, arguments)
state.payload = arguments[1]
state.failure_message = arguments[2]
end
assert:register("modifier", "stub", set_stub)
return setmetatable(stub, {
__call = function(self, ...)
-- stub originally was a function only. Now that it is a module table
-- the __call method is required for backward compatibility
return stub.new(...)
end,
})

@ -0,0 +1,386 @@
local util = {}
local arglist_mt = {}
-- have pack/unpack both respect the 'n' field
local _unpack = table.unpack or unpack
local unpack = function(t, i, j)
return _unpack(t, i or 1, j or t.n or #t)
end
local pack = function(...)
return {
n = select("#", ...),
...,
}
end
util.pack = pack
util.unpack = unpack
function util.deepcompare(t1, t2, ignore_mt, cycles, thresh1, thresh2)
local ty1 = type(t1)
local ty2 = type(t2)
-- non-table types can be directly compared
if ty1 ~= 'table' or ty2 ~= 'table' then
return t1 == t2
end
local mt1 = debug.getmetatable(t1)
local mt2 = debug.getmetatable(t2)
-- would equality be determined by metatable __eq?
if mt1 and mt1 == mt2 and mt1.__eq then
-- then use that unless asked not to
if not ignore_mt then
return t1 == t2
end
else -- we can skip the deep comparison below if t1 and t2 share identity
if rawequal(t1, t2) then
return true
end
end
-- handle recursive tables
cycles = cycles or {{}, {}}
thresh1, thresh2 = (thresh1 or 1), (thresh2 or 1)
cycles[1][t1] = (cycles[1][t1] or 0)
cycles[2][t2] = (cycles[2][t2] or 0)
if cycles[1][t1] == 1 or cycles[2][t2] == 1 then
thresh1 = cycles[1][t1] + 1
thresh2 = cycles[2][t2] + 1
end
if cycles[1][t1] > thresh1 and cycles[2][t2] > thresh2 then
return true
end
cycles[1][t1] = cycles[1][t1] + 1
cycles[2][t2] = cycles[2][t2] + 1
for k1, v1 in next, t1 do
local v2 = t2[k1]
if v2 == nil then
return false, {k1}
end
local same, crumbs = util.deepcompare(v1, v2, nil, cycles, thresh1, thresh2)
if not same then
crumbs = crumbs or {}
table.insert(crumbs, k1)
return false, crumbs
end
end
for k2, _ in next, t2 do
-- only check whether each element has a t1 counterpart, actual comparison
-- has been done in first loop above
if t1[k2] == nil then
return false, {k2}
end
end
cycles[1][t1] = cycles[1][t1] - 1
cycles[2][t2] = cycles[2][t2] - 1
return true
end
function util.shallowcopy(t)
if type(t) ~= "table" then
return t
end
local copy = {}
setmetatable(copy, getmetatable(t))
for k, v in next, t do
copy[k] = v
end
return copy
end
function util.deepcopy(t, deepmt, cache)
local spy = require 'luassert.spy'
if type(t) ~= "table" then
return t
end
local copy = {}
-- handle recursive tables
local cache = cache or {}
if cache[t] then
return cache[t]
end
cache[t] = copy
for k, v in next, t do
copy[k] = (spy.is_spy(v) and v or util.deepcopy(v, deepmt, cache))
end
if deepmt then
debug.setmetatable(copy, util.deepcopy(debug.getmetatable(t, nil, cache)))
else
debug.setmetatable(copy, debug.getmetatable(t))
end
return copy
end
-----------------------------------------------
-- Copies arguments as a list of arguments
-- @param args the arguments of which to copy
-- @return the copy of the arguments
function util.copyargs(args)
local copy = {}
setmetatable(copy, getmetatable(args))
local match = require 'luassert.match'
local spy = require 'luassert.spy'
for k, v in pairs(args) do
copy[k] = ((match.is_matcher(v) or spy.is_spy(v)) and v or util.deepcopy(v))
end
return {
vals = copy,
refs = util.shallowcopy(args),
}
end
-----------------------------------------------
-- Clear an arguments or return values list from a table
-- @param arglist the table to clear of arguments or return values and their count
-- @return No return values
function util.cleararglist(arglist)
for idx = arglist.n, 1, -1 do
util.tremove(arglist, idx)
end
arglist.n = nil
end
-----------------------------------------------
-- Test specs against an arglist in deepcopy and refs flavours.
-- @param args deepcopy arglist
-- @param argsrefs refs arglist
-- @param specs arguments/return values to match against args/argsrefs
-- @return true if specs match args/argsrefs, false otherwise
local function matcharg(args, argrefs, specs)
local match = require 'luassert.match'
for idx, argval in pairs(args) do
local spec = specs[idx]
if match.is_matcher(spec) then
if match.is_ref_matcher(spec) then
argval = argrefs[idx]
end
if not spec(argval) then
return false
end
elseif (spec == nil or not util.deepcompare(argval, spec)) then
return false
end
end
for idx, spec in pairs(specs) do
-- only check whether each element has an args counterpart,
-- actual comparison has been done in first loop above
local argval = args[idx]
if argval == nil then
-- no args counterpart, so try to compare using matcher
if match.is_matcher(spec) then
if not spec(argval) then
return false
end
else
return false
end
end
end
return true
end
-----------------------------------------------
-- Find matching arguments/return values in a saved list of
-- arguments/returned values.
-- @param invocations_list list of arguments/returned values to search (list of lists)
-- @param specs arguments/return values to match against argslist
-- @return the last matching arguments/returned values if a match is found, otherwise nil
function util.matchargs(invocations_list, specs)
-- Search the arguments/returned values last to first to give the
-- most helpful answer possible. In the cases where you can place
-- your assertions between calls to check this gives you the best
-- information if no calls match. In the cases where you can't do
-- that there is no good way to predict what would work best.
assert(not util.is_arglist(invocations_list), "expected a list of arglist-object, got an arglist")
for ii = #invocations_list, 1, -1 do
local val = invocations_list[ii]
if matcharg(val.vals, val.refs, specs) then
return val
end
end
return nil
end
-----------------------------------------------
-- Find matching oncall for an actual call.
-- @param oncalls list of oncalls to search
-- @param args actual call argslist to match against
-- @return the first matching oncall if a match is found, otherwise nil
function util.matchoncalls(oncalls, args)
for _, callspecs in ipairs(oncalls) do
-- This lookup is done immediately on *args* passing into the stub
-- so pass *args* as both *args* and *argsref* without copying
-- either.
if matcharg(args, args, callspecs.vals) then
return callspecs
end
end
return nil
end
-----------------------------------------------
-- table.insert() replacement that respects nil values.
-- The function will use table field 'n' as indicator of the
-- table length, if not set, it will be added.
-- @param t table into which to insert
-- @param pos (optional) position in table where to insert. NOTE: not optional if you want to insert a nil-value!
-- @param val value to insert
-- @return No return values
function util.tinsert(...)
-- check optional POS value
local args = {...}
local c = select('#', ...)
local t = args[1]
local pos = args[2]
local val = args[3]
if c < 3 then
val = pos
pos = nil
end
-- set length indicator n if not present (+1)
t.n = (t.n or #t) + 1
if not pos then
pos = t.n
elseif pos > t.n then
-- out of our range
t[pos] = val
t.n = pos
end
-- shift everything up 1 pos
for i = t.n, pos + 1, -1 do
t[i] = t[i - 1]
end
-- add element to be inserted
t[pos] = val
end
-----------------------------------------------
-- table.remove() replacement that respects nil values.
-- The function will use table field 'n' as indicator of the
-- table length, if not set, it will be added.
-- @param t table from which to remove
-- @param pos (optional) position in table to remove
-- @return No return values
function util.tremove(t, pos)
-- set length indicator n if not present (+1)
t.n = t.n or #t
if not pos then
pos = t.n
elseif pos > t.n then
local removed = t[pos]
-- out of our range
t[pos] = nil
return removed
end
local removed = t[pos]
-- shift everything up 1 pos
for i = pos, t.n do
t[i] = t[i + 1]
end
-- set size, clean last
t[t.n] = nil
t.n = t.n - 1
return removed
end
-----------------------------------------------
-- Checks an element to be callable.
-- The type must either be a function or have a metatable
-- containing an '__call' function.
-- @param object element to inspect on being callable or not
-- @return boolean, true if the object is callable
function util.callable(object)
return type(object) == "function" or type((debug.getmetatable(object) or {}).__call) == "function"
end
-----------------------------------------------
-- Checks an element has tostring.
-- The type must either be a string or have a metatable
-- containing an '__tostring' function.
-- @param object element to inspect on having tostring or not
-- @return boolean, true if the object has tostring
function util.hastostring(object)
return type(object) == "string" or type((debug.getmetatable(object) or {}).__tostring) == "function"
end
-----------------------------------------------
-- Find the first level, not defined in the same file as the caller's
-- code file to properly report an error.
-- @param level the level to use as the caller's source file
-- @return number, the level of which to report an error
function util.errorlevel(level)
local level = (level or 1) + 1 -- add one to get level of the caller
local info = debug.getinfo(level)
local source = (info or {}).source
local file = source
while file and (file == source or source == "=(tail call)") do
level = level + 1
info = debug.getinfo(level)
source = (info or {}).source
end
if level > 1 then
level = level - 1
end -- deduct call to errorlevel() itself
return level
end
-----------------------------------------------
-- Extract modifier and namespace keys from list of tokens.
-- @param nspace the namespace from which to match tokens
-- @param tokens list of tokens to search for keys
-- @return table, list of keys that were extracted
function util.extract_keys(nspace, tokens)
local namespace = require 'luassert.namespaces'
-- find valid keys by coalescing tokens as needed, starting from the end
local keys = {}
local key = nil
local i = #tokens
while i > 0 do
local token = tokens[i]
key = key and (token .. '_' .. key) or token
-- find longest matching key in the given namespace
local longkey = i > 1 and (tokens[i - 1] .. '_' .. key) or nil
while i > 1 and longkey and namespace[nspace][longkey] do
key = longkey
i = i - 1
token = tokens[i]
longkey = (token .. '_' .. key)
end
if namespace.modifier[key] or namespace[nspace][key] then
table.insert(keys, 1, key)
key = nil
end
i = i - 1
end
-- if there's anything left we didn't recognize it
if key then
error("luassert: unknown modifier/" .. nspace .. ": '" .. key .. "'", util.errorlevel(2))
end
return keys
end
-----------------------------------------------
-- store argument list for return values of a function in a table.
-- The table will get a metatable to identify it as an arglist
function util.make_arglist(...)
local arglist = {...}
arglist.n = select('#', ...) -- add values count for trailing nils
return setmetatable(arglist, arglist_mt)
end
-----------------------------------------------
-- check a table to be an arglist type.
function util.is_arglist(object)
return getmetatable(object) == arglist_mt
end
return util
Loading…
Cancel
Save