🔧 build: 增加 luassert 测试 库
parent
4e3704ac97
commit
803f7326c7
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,473 @@
|
||||
--[[
|
||||
Minimal test framework for Lua.
|
||||
lester - v0.1.2 - 15/Feb/2021
|
||||
Eduardo Bart - edub4rt@gmail.com
|
||||
https://github.com/edubart/lester
|
||||
Minimal Lua test framework.
|
||||
See end of file for LICENSE.
|
||||
]] --[[--
|
||||
Lester is a minimal unit testing framework for Lua with a focus on being simple to use.
|
||||
|
||||
## Features
|
||||
|
||||
* Minimal, just one file.
|
||||
* Self contained, no external dependencies.
|
||||
* Simple and hackable when needed.
|
||||
* Use `describe` and `it` blocks to describe tests.
|
||||
* Supports `before` and `after` handlers.
|
||||
* Colored output.
|
||||
* Configurable via the script or with environment variables.
|
||||
* Quiet mode, to use in live development.
|
||||
* Optionally filter tests by name.
|
||||
* Show traceback on errors.
|
||||
* Show time to complete tests.
|
||||
* Works with Lua 5.1+.
|
||||
* Efficient.
|
||||
|
||||
## Usage
|
||||
|
||||
Copy `lester.lua` file to a project and require it,
|
||||
which returns a table that includes all of the functionality:
|
||||
|
||||
```lua
|
||||
local lester = require 'lester'
|
||||
local describe, it, expect = lester.describe, lester.it, lester.expect
|
||||
|
||||
-- Customize lester configuration.
|
||||
lester.show_traceback = false
|
||||
|
||||
describe('my project', function()
|
||||
lester.before(function()
|
||||
-- This function is run before every test.
|
||||
end)
|
||||
|
||||
describe('module1', function() -- Describe blocks can be nested.
|
||||
it('feature1', function()
|
||||
expect.equal('something', 'something') -- Pass.
|
||||
end)
|
||||
|
||||
it('feature2', function()
|
||||
expect.truthy(false) -- Fail.
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
||||
lester.report() -- Print overall statistic of the tests run.
|
||||
lester.exit() -- Exit with success if all tests passed.
|
||||
```
|
||||
|
||||
## Customizing output with environment variables
|
||||
|
||||
To customize the output of lester externally,
|
||||
you can set the following environment variables before running a test suite:
|
||||
|
||||
* `LESTER_QUIET="true"`, omit print of passed tests.
|
||||
* `LESTER_COLORED="false"`, disable colored output.
|
||||
* `LESTER_SHOW_TRACEBACK="false"`, disable traceback on test failures.
|
||||
* `LESTER_SHOW_ERROR="false"`, omit print of error description of failed tests.
|
||||
* `LESTER_STOP_ON_FAIL="true"`, stop on first test failure.
|
||||
* `LESTER_UTF8TERM="false"`, disable printing of UTF-8 characters.
|
||||
* `LESTER_FILTER="some text"`, filter the tests that should be run.
|
||||
|
||||
Note that these configurations can be changed via script too, check the documentation.
|
||||
|
||||
]] -- Returns whether the terminal supports UTF-8 characters.
|
||||
local function is_utf8term()
|
||||
local lang = os.getenv('LANG')
|
||||
return (lang and lang:lower():match('utf%-8$')) and true or false
|
||||
end
|
||||
|
||||
-- Returns whether a system environment variable is "true".
|
||||
local function getboolenv(varname, default)
|
||||
local val = os.getenv(varname)
|
||||
if val == 'true' then
|
||||
return true
|
||||
elseif val == 'false' then
|
||||
return false
|
||||
end
|
||||
return default
|
||||
end
|
||||
|
||||
-- The lester module.
|
||||
local lester = {
|
||||
--- Weather lines of passed tests should not be printed. False by default.
|
||||
quiet = getboolenv('LESTER_QUIET', false),
|
||||
--- Weather the output should be colorized. True by default.
|
||||
colored = getboolenv('LESTER_COLORED', true),
|
||||
--- Weather a traceback must be shown on test failures. True by default.
|
||||
show_traceback = getboolenv('LESTER_SHOW_TRACEBACK', true),
|
||||
--- Weather the error description of a test failure should be shown. True by default.
|
||||
show_error = getboolenv('LESTER_SHOW_ERROR', true),
|
||||
--- Weather test suite should exit on first test failure. False by default.
|
||||
stop_on_fail = getboolenv('LESTER_STOP_ON_FAIL', false),
|
||||
--- Weather we can print UTF-8 characters to the terminal. True by default when supported.
|
||||
utf8term = getboolenv('LESTER_UTF8TERM', is_utf8term()),
|
||||
--- A string with a lua pattern to filter tests. Nil by default.
|
||||
filter = os.getenv('LESTER_FILTER'),
|
||||
--- Function to retrieve time in seconds with milliseconds precision, `os.clock` by default.
|
||||
seconds = os.clock,
|
||||
}
|
||||
|
||||
-- Variables used internally for the lester state.
|
||||
local lester_start = nil
|
||||
local last_succeeded = false
|
||||
local level = 0
|
||||
local successes = 0
|
||||
local total_successes = 0
|
||||
local failures = 0
|
||||
local total_failures = 0
|
||||
local start = 0
|
||||
local befores = {}
|
||||
local afters = {}
|
||||
local names = {}
|
||||
|
||||
-- Color codes.
|
||||
local color_codes = {
|
||||
reset = string.char(27) .. '[0m',
|
||||
bright = string.char(27) .. '[1m',
|
||||
red = string.char(27) .. '[31m',
|
||||
green = string.char(27) .. '[32m',
|
||||
blue = string.char(27) .. '[34m',
|
||||
magenta = string.char(27) .. '[35m',
|
||||
}
|
||||
|
||||
-- Colors table, returning proper color code if colored mode is enabled.
|
||||
local colors = setmetatable({}, {
|
||||
__index = function(_, key)
|
||||
return lester.colored and color_codes[key] or ''
|
||||
end,
|
||||
})
|
||||
|
||||
--- Table of terminal colors codes, can be customized.
|
||||
lester.colors = colors
|
||||
|
||||
--- Describe a block of tests, which consists in a set of tests.
|
||||
-- Describes can be nested.
|
||||
-- @param name A string used to describe the block.
|
||||
-- @param func A function containing all the tests or other describes.
|
||||
function lester.describe(name, func)
|
||||
if level == 0 then -- Get start time for top level describe blocks.
|
||||
start = lester.seconds()
|
||||
if not lester_start then
|
||||
lester_start = start
|
||||
end
|
||||
end
|
||||
-- Setup describe block variables.
|
||||
failures = 0
|
||||
successes = 0
|
||||
level = level + 1
|
||||
names[level] = name
|
||||
-- Run the describe block.
|
||||
func()
|
||||
-- Cleanup describe block.
|
||||
afters[level] = nil
|
||||
befores[level] = nil
|
||||
names[level] = nil
|
||||
level = level - 1
|
||||
-- Pretty print statistics for top level describe block.
|
||||
if level == 0 and not lester.quiet and (successes > 0 or failures > 0) then
|
||||
local io_write = io.write
|
||||
local colors_reset, colors_green = colors.reset, colors.green
|
||||
io_write(failures == 0 and colors_green or colors.red, '[====] ', colors.magenta, name, colors_reset, ' | ',
|
||||
colors_green, successes, colors_reset, ' successes / ')
|
||||
if failures > 0 then
|
||||
io_write(colors.red, failures, colors_reset, ' failures / ')
|
||||
end
|
||||
io_write(colors.bright, string.format('%.6f', lester.seconds() - start), colors_reset, ' seconds\n')
|
||||
end
|
||||
end
|
||||
|
||||
-- Error handler used to get traceback for errors.
|
||||
local function xpcall_error_handler(err)
|
||||
return debug.traceback(tostring(err), 2)
|
||||
end
|
||||
|
||||
-- Pretty print the line on the test file where an error happened.
|
||||
local function show_error_line(err)
|
||||
local info = debug.getinfo(3)
|
||||
local io_write = io.write
|
||||
local colors_reset = colors.reset
|
||||
local short_src, currentline = info.short_src, info.currentline
|
||||
io_write(' (', colors.blue, short_src, colors_reset, ':', colors.bright, currentline, colors_reset)
|
||||
if err and lester.show_traceback then
|
||||
local fnsrc = short_src .. ':' .. currentline
|
||||
for cap1, cap2 in err:gmatch('\t[^\n:]+:(%d+): in function <([^>]+)>\n') do
|
||||
if cap2 == fnsrc then
|
||||
io_write('/', colors.bright, cap1, colors_reset)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
io_write(')')
|
||||
end
|
||||
|
||||
-- Pretty print the test name, with breadcrumb for the describe blocks.
|
||||
local function show_test_name(name)
|
||||
local io_write = io.write
|
||||
local colors_reset = colors.reset
|
||||
for _, descname in ipairs(names) do
|
||||
io_write(colors.magenta, descname, colors_reset, ' | ')
|
||||
end
|
||||
io_write(colors.bright, name, colors_reset)
|
||||
end
|
||||
|
||||
--- Declare a test, which consists of a set of assertions.
|
||||
-- @param name A name for the test.
|
||||
-- @param func The function containing all assertions.
|
||||
function lester.it(name, func)
|
||||
-- Skip the test if it does not match the filter.
|
||||
if lester.filter then
|
||||
local fullname = table.concat(names, ' | ') .. ' | ' .. name
|
||||
if not fullname:match(lester.filter) then
|
||||
return
|
||||
end
|
||||
end
|
||||
-- Execute before handlers.
|
||||
for _, levelbefores in ipairs(befores) do
|
||||
for _, beforefn in ipairs(levelbefores) do
|
||||
beforefn(name)
|
||||
end
|
||||
end
|
||||
-- Run the test, capturing errors if any.
|
||||
local success, err
|
||||
if lester.show_traceback then
|
||||
success, err = xpcall(func, xpcall_error_handler)
|
||||
else
|
||||
success, err = pcall(func)
|
||||
if not success and err then
|
||||
err = tostring(err)
|
||||
end
|
||||
end
|
||||
-- Count successes and failures.
|
||||
if success then
|
||||
successes = successes + 1
|
||||
total_successes = total_successes + 1
|
||||
else
|
||||
failures = failures + 1
|
||||
total_failures = total_failures + 1
|
||||
end
|
||||
local io_write = io.write
|
||||
local colors_reset = colors.reset
|
||||
-- Print the test run.
|
||||
if not lester.quiet then -- Show test status and complete test name.
|
||||
if success then
|
||||
io_write(colors.green, '[PASS] ', colors_reset)
|
||||
else
|
||||
io_write(colors.red, '[FAIL] ', colors_reset)
|
||||
end
|
||||
show_test_name(name)
|
||||
if not success then
|
||||
show_error_line(err)
|
||||
end
|
||||
io_write('\n')
|
||||
else
|
||||
if success then -- Show just a character hinting that the test succeeded.
|
||||
local o = (lester.utf8term and lester.colored) and string.char(226, 151, 143) or 'o'
|
||||
io_write(colors.green, o, colors_reset)
|
||||
else -- Show complete test name on failure.
|
||||
io_write(last_succeeded and '\n' or '', colors.red, '[FAIL] ', colors_reset)
|
||||
show_test_name(name)
|
||||
show_error_line(err)
|
||||
io_write('\n')
|
||||
end
|
||||
end
|
||||
-- Print error message, colorizing its output if possible.
|
||||
if err and lester.show_error then
|
||||
if lester.colored then
|
||||
local errfile, errline, errmsg, rest = err:match('^([^:\n]+):(%d+): ([^\n]+)(.*)')
|
||||
if errfile and errline and errmsg and rest then
|
||||
io_write(colors.blue, errfile, colors_reset, ':', colors.bright, errline, colors_reset, ': ')
|
||||
if errmsg:match('^%w([^:]*)$') then
|
||||
io_write(colors.red, errmsg, colors_reset)
|
||||
else
|
||||
io_write(errmsg)
|
||||
end
|
||||
err = rest
|
||||
end
|
||||
end
|
||||
io_write(err, '\n\n')
|
||||
end
|
||||
io.flush()
|
||||
-- Stop on failure.
|
||||
if not success and lester.stop_on_fail then
|
||||
if lester.quiet then
|
||||
io_write('\n')
|
||||
io.flush()
|
||||
end
|
||||
lester.exit()
|
||||
end
|
||||
-- Execute after handlers.
|
||||
for _, levelafters in ipairs(afters) do
|
||||
for _, afterfn in ipairs(levelafters) do
|
||||
afterfn(name)
|
||||
end
|
||||
end
|
||||
last_succeeded = success
|
||||
end
|
||||
|
||||
--- Set a function that is called before every test inside a describe block.
|
||||
-- A single string containing the name of the test about to be run will be passed to `func`.
|
||||
function lester.before(func)
|
||||
local levelbefores = befores[level]
|
||||
if not levelbefores then
|
||||
levelbefores = {}
|
||||
befores[level] = levelbefores
|
||||
end
|
||||
levelbefores[#levelbefores + 1] = func
|
||||
end
|
||||
|
||||
--- Set a function that is called after every test inside a describe block.
|
||||
-- A single string containing the name of the test that was finished will be passed to `func`.
|
||||
-- The function is executed independently if the test passed or failed.
|
||||
function lester.after(func)
|
||||
local levelafters = afters[level]
|
||||
if not levelafters then
|
||||
levelafters = {}
|
||||
afters[level] = levelafters
|
||||
end
|
||||
levelafters[#levelafters + 1] = func
|
||||
end
|
||||
|
||||
--- Pretty print statistics of all test runs.
|
||||
-- With total success, total failures and run time in seconds.
|
||||
function lester.report()
|
||||
local now = lester.seconds()
|
||||
local colors_reset = colors.reset
|
||||
io.write(lester.quiet and '\n' or '', colors.green, total_successes, colors_reset, ' successes / ', colors.red,
|
||||
total_failures, colors_reset, ' failures / ', colors.bright, string.format('%.6f', now - (lester_start or now)),
|
||||
colors_reset, ' seconds\n')
|
||||
io.flush()
|
||||
return total_failures == 0
|
||||
end
|
||||
|
||||
--- Exit the application with success code if all tests passed, or failure code otherwise.
|
||||
function lester.exit()
|
||||
os.exit(total_failures == 0)
|
||||
end
|
||||
|
||||
local expect = {}
|
||||
--- Expect module, containing utility function for doing assertions inside a test.
|
||||
lester.expect = expect
|
||||
|
||||
--- Check if a function fails with an error.
|
||||
-- If `expected` is nil then any error is accepted.
|
||||
-- If `expected` is a string then we check if the error contains that string.
|
||||
-- If `expected` is anything else then we check if both are equal.
|
||||
function expect.fail(func, expected)
|
||||
local ok, err = pcall(func)
|
||||
if ok then
|
||||
error('expected function to fail', 2)
|
||||
elseif expected ~= nil then
|
||||
local found = expected == err
|
||||
if not found and type(expected) == 'string' then
|
||||
found = string.find(tostring(err), expected, 1, true)
|
||||
end
|
||||
if not found then
|
||||
error('expected function to fail\nexpected:\n' .. tostring(expected) .. '\ngot:\n' .. tostring(err), 2)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if a function does not fail with a error.
|
||||
function expect.not_fail(func)
|
||||
local ok, err = pcall(func)
|
||||
if not ok then
|
||||
error('expected function to not fail\ngot error:\n' .. tostring(err), 2)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if a value is not `nil`.
|
||||
function expect.exist(v)
|
||||
if v == nil then
|
||||
error('expected value to exist\ngot:\n' .. tostring(v), 2)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if a value is `nil`.
|
||||
function expect.not_exist(v)
|
||||
if v ~= nil then
|
||||
error('expected value to not exist\ngot:\n' .. tostring(v), 2)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if an expression is evaluates to `true`.
|
||||
function expect.truthy(v)
|
||||
if not v then
|
||||
error('expected expression to be true\ngot:\n' .. tostring(v), 2)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if an expression is evaluates to `false`.
|
||||
function expect.falsy(v)
|
||||
if v then
|
||||
error('expected expression to be false\ngot:\n' .. tostring(v), 2)
|
||||
end
|
||||
end
|
||||
|
||||
--- Compare if two values are equal, considering nested tables.
|
||||
local function strict_eq(t1, t2)
|
||||
if rawequal(t1, t2) then
|
||||
return true
|
||||
end
|
||||
if type(t1) ~= type(t2) then
|
||||
return false
|
||||
end
|
||||
if type(t1) ~= 'table' then
|
||||
return t1 == t2
|
||||
end
|
||||
if getmetatable(t1) ~= getmetatable(t2) then
|
||||
return false
|
||||
end
|
||||
for k, v1 in pairs(t1) do
|
||||
if not strict_eq(v1, t2[k]) then
|
||||
return false
|
||||
end
|
||||
end
|
||||
for k, v2 in pairs(t2) do
|
||||
if not strict_eq(v2, t1[k]) then
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
--- Check if two values are equal.
|
||||
function expect.equal(v1, v2)
|
||||
if not strict_eq(v1, v2) then
|
||||
error('expected values to be equal\nfirst value:\n' .. tostring(v1) .. '\nsecond value:\n' .. tostring(v2), 2)
|
||||
end
|
||||
end
|
||||
|
||||
--- Check if two values are not equal.
|
||||
function expect.not_equal(v1, v2)
|
||||
if strict_eq(v1, v2) then
|
||||
error('expected values to be not equal\nfirst value:\n' .. tostring(v1) .. '\nsecond value:\n' .. tostring(v2),
|
||||
2)
|
||||
end
|
||||
end
|
||||
|
||||
return lester
|
||||
|
||||
--[[
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2021 Eduardo Bart (https://github.com/edubart)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
]]
|
||||
@ -0,0 +1,70 @@
|
||||
local assert = require('luassert.assert')
|
||||
local say = require('luassert.say')
|
||||
|
||||
-- Example usage:
|
||||
-- local arr = { "one", "two", "three" }
|
||||
--
|
||||
-- assert.array(arr).has.no.holes() -- checks the array to not contain holes --> passes
|
||||
-- assert.array(arr).has.no.holes(4) -- sets explicit length to 4 --> fails
|
||||
--
|
||||
-- local first_hole = assert.array(arr).has.holes(4) -- check array of size 4 to contain holes --> passes
|
||||
-- assert.equal(4, first_hole) -- passes, as the index of the first hole is returned
|
||||
|
||||
|
||||
-- Unique key to store the object we operate on in the state object
|
||||
-- key must be unique, to make sure we do not have name collissions in the shared state object
|
||||
local ARRAY_STATE_KEY = "__array_state"
|
||||
|
||||
-- The modifier, to store the object in our state
|
||||
local function array(state, args, level)
|
||||
assert(args.n > 0, "No array provided to the array-modifier")
|
||||
assert(rawget(state, ARRAY_STATE_KEY) == nil, "Array already set")
|
||||
rawset(state, ARRAY_STATE_KEY, args[1])
|
||||
return state
|
||||
end
|
||||
|
||||
-- The actual assertion that operates on our object, stored via the modifier
|
||||
local function holes(state, args, level)
|
||||
local length = args[1]
|
||||
local arr = rawget(state, ARRAY_STATE_KEY) -- retrieve previously set object
|
||||
-- only check against nil, metatable types are allowed
|
||||
assert(arr ~= nil, "No array set, please use the array modifier to set the array to validate")
|
||||
if length == nil then
|
||||
length = 0
|
||||
for i in pairs(arr) do
|
||||
if type(i) == "number" and
|
||||
i > length and
|
||||
math.floor(i) == i then
|
||||
length = i
|
||||
end
|
||||
end
|
||||
end
|
||||
assert(type(length) == "number", "expected array length to be of type 'number', got: "..tostring(length))
|
||||
-- let's do the actual assertion
|
||||
local missing
|
||||
for i = 1, length do
|
||||
if arr[i] == nil then
|
||||
missing = i
|
||||
break
|
||||
end
|
||||
end
|
||||
-- format arguments for output strings;
|
||||
args[1] = missing
|
||||
args.n = missing and 1 or 0
|
||||
return missing ~= nil, { missing } -- assert result + first missing index as return value
|
||||
end
|
||||
|
||||
-- Register the proper assertion messages
|
||||
say:set("assertion.array_holes.positive", [[
|
||||
Expected array to have holes, but none was found.
|
||||
]])
|
||||
say:set("assertion.array_holes.negative", [[
|
||||
Expected array to not have holes, hole found at position: %s
|
||||
]])
|
||||
|
||||
-- Register the assertion, and the modifier
|
||||
assert:register("assertion", "holes", holes,
|
||||
"assertion.array_holes.positive",
|
||||
"assertion.array_holes.negative")
|
||||
|
||||
assert:register("modifier", "array", array)
|
||||
@ -0,0 +1,180 @@
|
||||
local s = require 'luassert.say'
|
||||
local astate = require 'luassert.state'
|
||||
local util = require 'luassert.util'
|
||||
local unpack = util.unpack
|
||||
local obj -- the returned module table
|
||||
local level_mt = {}
|
||||
|
||||
-- list of namespaces
|
||||
local namespace = require 'luassert.namespaces'
|
||||
|
||||
local function geterror(assertion_message, failure_message, args)
|
||||
if util.hastostring(failure_message) then
|
||||
failure_message = tostring(failure_message)
|
||||
elseif failure_message ~= nil then
|
||||
failure_message = astate.format_argument(failure_message)
|
||||
end
|
||||
local message = s(assertion_message, obj:format(args))
|
||||
if message and failure_message then
|
||||
message = failure_message .. "\n" .. message
|
||||
end
|
||||
return message or failure_message
|
||||
end
|
||||
|
||||
local __state_meta = {
|
||||
|
||||
__call = function(self, ...)
|
||||
local keys = util.extract_keys("assertion", self.tokens)
|
||||
|
||||
local assertion
|
||||
|
||||
for _, key in ipairs(keys) do
|
||||
assertion = namespace.assertion[key] or assertion
|
||||
end
|
||||
|
||||
if assertion then
|
||||
for _, key in ipairs(keys) do
|
||||
if namespace.modifier[key] then
|
||||
namespace.modifier[key].callback(self)
|
||||
end
|
||||
end
|
||||
|
||||
local arguments = util.make_arglist(...)
|
||||
local val, retargs = assertion.callback(self, arguments, util.errorlevel())
|
||||
|
||||
if not val == self.mod then
|
||||
local message = assertion.positive_message
|
||||
if not self.mod then
|
||||
message = assertion.negative_message
|
||||
end
|
||||
local err = geterror(message, rawget(self,"failure_message"), arguments)
|
||||
error(err or "assertion failed!", util.errorlevel())
|
||||
end
|
||||
|
||||
if retargs then
|
||||
return unpack(retargs)
|
||||
end
|
||||
return ...
|
||||
else
|
||||
local arguments = util.make_arglist(...)
|
||||
self.tokens = {}
|
||||
|
||||
for _, key in ipairs(keys) do
|
||||
if namespace.modifier[key] then
|
||||
namespace.modifier[key].callback(self, arguments, util.errorlevel())
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return self
|
||||
end,
|
||||
|
||||
__index = function(self, key)
|
||||
for token in key:lower():gmatch('[^_]+') do
|
||||
table.insert(self.tokens, token)
|
||||
end
|
||||
|
||||
return self
|
||||
end
|
||||
}
|
||||
|
||||
obj = {
|
||||
state = function() return setmetatable({mod=true, tokens={}}, __state_meta) end,
|
||||
|
||||
-- registers a function in namespace
|
||||
register = function(self, nspace, name, callback, positive_message, negative_message)
|
||||
local lowername = name:lower()
|
||||
if not namespace[nspace] then
|
||||
namespace[nspace] = {}
|
||||
end
|
||||
namespace[nspace][lowername] = {
|
||||
callback = callback,
|
||||
name = lowername,
|
||||
positive_message=positive_message,
|
||||
negative_message=negative_message
|
||||
}
|
||||
end,
|
||||
|
||||
-- unregisters a function in a namespace
|
||||
unregister = function(self, nspace, name)
|
||||
local lowername = name:lower()
|
||||
if not namespace[nspace] then
|
||||
namespace[nspace] = {}
|
||||
end
|
||||
namespace[nspace][lowername] = nil
|
||||
end,
|
||||
|
||||
-- registers a formatter
|
||||
-- a formatter takes a single argument, and converts it to a string, or returns nil if it cannot format the argument
|
||||
add_formatter = function(self, callback)
|
||||
astate.add_formatter(callback)
|
||||
end,
|
||||
|
||||
-- unregisters a formatter
|
||||
remove_formatter = function(self, fmtr)
|
||||
astate.remove_formatter(fmtr)
|
||||
end,
|
||||
|
||||
format = function(self, args)
|
||||
-- args.n specifies the number of arguments in case of 'trailing nil' arguments which get lost
|
||||
local nofmt = args.nofmt or {} -- arguments in this list should not be formatted
|
||||
local fmtargs = args.fmtargs or {} -- additional arguments to be passed to formatter
|
||||
for i = 1, (args.n or #args) do -- cannot use pairs because table might have nils
|
||||
if not nofmt[i] then
|
||||
local val = args[i]
|
||||
local valfmt = astate.format_argument(val, nil, fmtargs[i])
|
||||
if valfmt == nil then valfmt = tostring(val) end -- no formatter found
|
||||
args[i] = valfmt
|
||||
end
|
||||
end
|
||||
return args
|
||||
end,
|
||||
|
||||
set_parameter = function(self, name, value)
|
||||
astate.set_parameter(name, value)
|
||||
end,
|
||||
|
||||
get_parameter = function(self, name)
|
||||
return astate.get_parameter(name)
|
||||
end,
|
||||
|
||||
add_spy = function(self, spy)
|
||||
astate.add_spy(spy)
|
||||
end,
|
||||
|
||||
snapshot = function(self)
|
||||
return astate.snapshot()
|
||||
end,
|
||||
|
||||
level = function(self, level)
|
||||
return setmetatable({
|
||||
level = level
|
||||
}, level_mt)
|
||||
end,
|
||||
|
||||
-- returns the level if a level-value, otherwise nil
|
||||
get_level = function(self, level)
|
||||
if getmetatable(level) ~= level_mt then
|
||||
return nil -- not a valid error-level
|
||||
end
|
||||
return level.level
|
||||
end,
|
||||
}
|
||||
|
||||
local __meta = {
|
||||
|
||||
__call = function(self, bool, message, level, ...)
|
||||
if not bool then
|
||||
local err_level = (self:get_level(level) or 1) + 1
|
||||
error(message or "assertion failed!", err_level)
|
||||
end
|
||||
return bool , message , level , ...
|
||||
end,
|
||||
|
||||
__index = function(self, key)
|
||||
return rawget(self, key) or self.state()[key]
|
||||
end,
|
||||
|
||||
}
|
||||
|
||||
return setmetatable(obj, __meta)
|
||||
@ -0,0 +1,328 @@
|
||||
-- module will not return anything, only register assertions with the main assert engine
|
||||
|
||||
-- assertions take 2 parameters;
|
||||
-- 1) state
|
||||
-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils
|
||||
-- 3) level The level of the error position relative to the called function
|
||||
-- returns; boolean; whether assertion passed
|
||||
|
||||
local assert = require('luassert.assert')
|
||||
local astate = require ('luassert.state')
|
||||
local util = require ('luassert.util')
|
||||
local s = require('luassert.say')
|
||||
|
||||
local function format(val)
|
||||
return astate.format_argument(val) or tostring(val)
|
||||
end
|
||||
|
||||
local function set_failure_message(state, message)
|
||||
if message ~= nil then
|
||||
state.failure_message = message
|
||||
end
|
||||
end
|
||||
|
||||
local function unique(state, arguments, level)
|
||||
local list = arguments[1]
|
||||
local deep
|
||||
local argcnt = arguments.n
|
||||
if type(arguments[2]) == "boolean" or (arguments[2] == nil and argcnt > 2) then
|
||||
deep = arguments[2]
|
||||
set_failure_message(state, arguments[3])
|
||||
else
|
||||
if type(arguments[3]) == "boolean" then
|
||||
deep = arguments[3]
|
||||
end
|
||||
set_failure_message(state, arguments[2])
|
||||
end
|
||||
for k,v in pairs(list) do
|
||||
for k2, v2 in pairs(list) do
|
||||
if k ~= k2 then
|
||||
if deep and util.deepcompare(v, v2, true) then
|
||||
return false
|
||||
else
|
||||
if v == v2 then
|
||||
return false
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
local function near(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 2, s("assertion.internal.argtolittle", { "near", 3, tostring(argcnt) }), level)
|
||||
local expected = tonumber(arguments[1])
|
||||
local actual = tonumber(arguments[2])
|
||||
local tolerance = tonumber(arguments[3])
|
||||
local numbertype = "number or object convertible to a number"
|
||||
assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level)
|
||||
assert(actual, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level)
|
||||
assert(tolerance, s("assertion.internal.badargtype", { 3, "near", numbertype, format(arguments[3]) }), level)
|
||||
-- switch arguments for proper output message
|
||||
util.tinsert(arguments, 1, util.tremove(arguments, 2))
|
||||
arguments[3] = tolerance
|
||||
arguments.nofmt = arguments.nofmt or {}
|
||||
arguments.nofmt[3] = true
|
||||
set_failure_message(state, arguments[4])
|
||||
return (actual >= expected - tolerance and actual <= expected + tolerance)
|
||||
end
|
||||
|
||||
local function matches(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 1, s("assertion.internal.argtolittle", { "matches", 2, tostring(argcnt) }), level)
|
||||
local pattern = arguments[1]
|
||||
local actual = nil
|
||||
if util.hastostring(arguments[2]) or type(arguments[2]) == "number" then
|
||||
actual = tostring(arguments[2])
|
||||
end
|
||||
local err_message
|
||||
local init_arg_num = 3
|
||||
for i=3,argcnt,1 do
|
||||
if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then
|
||||
if i == 3 then init_arg_num = init_arg_num + 1 end
|
||||
err_message = util.tremove(arguments, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
local init = arguments[3]
|
||||
local plain = arguments[4]
|
||||
local stringtype = "string or object convertible to a string"
|
||||
assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
|
||||
assert(actual, s("assertion.internal.badargtype", { 2, "matches", stringtype, format(arguments[2]) }), level)
|
||||
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level)
|
||||
-- switch arguments for proper output message
|
||||
util.tinsert(arguments, 1, util.tremove(arguments, 2))
|
||||
set_failure_message(state, err_message)
|
||||
local retargs
|
||||
local ok
|
||||
if plain then
|
||||
ok = (actual:find(pattern, init, plain) ~= nil)
|
||||
retargs = ok and { pattern } or {}
|
||||
else
|
||||
retargs = { actual:match(pattern, init) }
|
||||
ok = (retargs[1] ~= nil)
|
||||
end
|
||||
return ok, retargs
|
||||
end
|
||||
|
||||
local function equals(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 1, s("assertion.internal.argtolittle", { "equals", 2, tostring(argcnt) }), level)
|
||||
local result = arguments[1] == arguments[2]
|
||||
-- switch arguments for proper output message
|
||||
util.tinsert(arguments, 1, util.tremove(arguments, 2))
|
||||
set_failure_message(state, arguments[3])
|
||||
return result
|
||||
end
|
||||
|
||||
local function same(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 1, s("assertion.internal.argtolittle", { "same", 2, tostring(argcnt) }), level)
|
||||
if type(arguments[1]) == 'table' and type(arguments[2]) == 'table' then
|
||||
local result, crumbs = util.deepcompare(arguments[1], arguments[2], true)
|
||||
-- switch arguments for proper output message
|
||||
util.tinsert(arguments, 1, util.tremove(arguments, 2))
|
||||
arguments.fmtargs = arguments.fmtargs or {}
|
||||
arguments.fmtargs[1] = { crumbs = crumbs }
|
||||
arguments.fmtargs[2] = { crumbs = crumbs }
|
||||
set_failure_message(state, arguments[3])
|
||||
return result
|
||||
end
|
||||
local result = arguments[1] == arguments[2]
|
||||
-- switch arguments for proper output message
|
||||
util.tinsert(arguments, 1, util.tremove(arguments, 2))
|
||||
set_failure_message(state, arguments[3])
|
||||
return result
|
||||
end
|
||||
|
||||
local function truthy(state, arguments, level)
|
||||
set_failure_message(state, arguments[2])
|
||||
return arguments[1] ~= false and arguments[1] ~= nil
|
||||
end
|
||||
|
||||
local function falsy(state, arguments, level)
|
||||
return not truthy(state, arguments, level)
|
||||
end
|
||||
|
||||
local function has_error(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local retargs = util.shallowcopy(arguments)
|
||||
local func = arguments[1]
|
||||
local err_expected = arguments[2]
|
||||
local failure_message = arguments[3]
|
||||
assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error", "function or callable object", type(func) }), level)
|
||||
local ok, err_actual = pcall(func)
|
||||
if type(err_actual) == 'string' then
|
||||
-- remove 'path/to/file:line: ' from string
|
||||
err_actual = err_actual:gsub('^.-:%d+: ', '', 1)
|
||||
end
|
||||
retargs[1] = err_actual
|
||||
arguments.nofmt = {}
|
||||
arguments.n = 2
|
||||
arguments[1] = (ok and '(no error)' or err_actual)
|
||||
arguments[2] = (err_expected == nil and '(error)' or err_expected)
|
||||
arguments.nofmt[1] = ok
|
||||
arguments.nofmt[2] = (err_expected == nil)
|
||||
set_failure_message(state, failure_message)
|
||||
|
||||
if ok or err_expected == nil then
|
||||
return not ok, retargs
|
||||
end
|
||||
if type(err_expected) == 'string' then
|
||||
-- err_actual must be (convertible to) a string
|
||||
if util.hastostring(err_actual) then
|
||||
err_actual = tostring(err_actual)
|
||||
retargs[1] = err_actual
|
||||
end
|
||||
if type(err_actual) == 'string' then
|
||||
return err_expected == err_actual, retargs
|
||||
end
|
||||
elseif type(err_expected) == 'number' then
|
||||
if type(err_actual) == 'string' then
|
||||
return tostring(err_expected) == tostring(tonumber(err_actual)), retargs
|
||||
end
|
||||
end
|
||||
return same(state, {err_expected, err_actual, ["n"] = 2}), retargs
|
||||
end
|
||||
|
||||
local function error_matches(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local retargs = util.shallowcopy(arguments)
|
||||
local argcnt = arguments.n
|
||||
local func = arguments[1]
|
||||
local pattern = arguments[2]
|
||||
assert(argcnt > 1, s("assertion.internal.argtolittle", { "error_matches", 2, tostring(argcnt) }), level)
|
||||
assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error_matches", "function or callable object", type(func) }), level)
|
||||
assert(pattern == nil or type(pattern) == "string", s("assertion.internal.badargtype", { 2, "error", "string", type(pattern) }), level)
|
||||
|
||||
local failure_message
|
||||
local init_arg_num = 3
|
||||
for i=3,argcnt,1 do
|
||||
if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then
|
||||
if i == 3 then init_arg_num = init_arg_num + 1 end
|
||||
failure_message = util.tremove(arguments, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
local init = arguments[3]
|
||||
local plain = arguments[4]
|
||||
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level)
|
||||
|
||||
local ok, err_actual = pcall(func)
|
||||
if type(err_actual) == 'string' then
|
||||
-- remove 'path/to/file:line: ' from string
|
||||
err_actual = err_actual:gsub('^.-:%d+: ', '', 1)
|
||||
end
|
||||
retargs[1] = err_actual
|
||||
arguments.nofmt = {}
|
||||
arguments.n = 2
|
||||
arguments[1] = (ok and '(no error)' or err_actual)
|
||||
arguments[2] = pattern
|
||||
arguments.nofmt[1] = ok
|
||||
arguments.nofmt[2] = false
|
||||
set_failure_message(state, failure_message)
|
||||
|
||||
if ok then return not ok, retargs end
|
||||
if err_actual == nil and pattern == nil then
|
||||
return true, {}
|
||||
end
|
||||
|
||||
-- err_actual must be (convertible to) a string
|
||||
if util.hastostring(err_actual) then
|
||||
err_actual = tostring(err_actual)
|
||||
retargs[1] = err_actual
|
||||
end
|
||||
if type(err_actual) == 'string' then
|
||||
local ok
|
||||
local retargs_ok
|
||||
if plain then
|
||||
retargs_ok = { pattern }
|
||||
ok = (err_actual:find(pattern, init, plain) ~= nil)
|
||||
else
|
||||
retargs_ok = { err_actual:match(pattern, init) }
|
||||
ok = (retargs_ok[1] ~= nil)
|
||||
end
|
||||
if ok then retargs = retargs_ok end
|
||||
return ok, retargs
|
||||
end
|
||||
|
||||
return false, retargs
|
||||
end
|
||||
|
||||
local function is_true(state, arguments, level)
|
||||
util.tinsert(arguments, 2, true)
|
||||
set_failure_message(state, arguments[3])
|
||||
return arguments[1] == arguments[2]
|
||||
end
|
||||
|
||||
local function is_false(state, arguments, level)
|
||||
util.tinsert(arguments, 2, false)
|
||||
set_failure_message(state, arguments[3])
|
||||
return arguments[1] == arguments[2]
|
||||
end
|
||||
|
||||
local function is_type(state, arguments, level, etype)
|
||||
util.tinsert(arguments, 2, "type " .. etype)
|
||||
arguments.nofmt = arguments.nofmt or {}
|
||||
arguments.nofmt[2] = true
|
||||
set_failure_message(state, arguments[3])
|
||||
return arguments.n > 1 and type(arguments[1]) == etype
|
||||
end
|
||||
|
||||
local function returned_arguments(state, arguments, level)
|
||||
arguments[1] = tostring(arguments[1])
|
||||
arguments[2] = tostring(arguments.n - 1)
|
||||
arguments.nofmt = arguments.nofmt or {}
|
||||
arguments.nofmt[1] = true
|
||||
arguments.nofmt[2] = true
|
||||
if arguments.n < 2 then arguments.n = 2 end
|
||||
return arguments[1] == arguments[2]
|
||||
end
|
||||
|
||||
local function set_message(state, arguments, level)
|
||||
state.failure_message = arguments[1]
|
||||
end
|
||||
|
||||
local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end
|
||||
local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end
|
||||
local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end
|
||||
local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end
|
||||
local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end
|
||||
local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end
|
||||
local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end
|
||||
local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end
|
||||
|
||||
assert:register("modifier", "message", set_message)
|
||||
assert:register("assertion", "true", is_true, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "false", is_false, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "boolean", is_boolean, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "number", is_number, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "string", is_string, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "table", is_table, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "nil", is_nil, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "userdata", is_userdata, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "function", is_function, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "thread", is_thread, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "returned_arguments", returned_arguments, "assertion.returned_arguments.positive", "assertion.returned_arguments.negative")
|
||||
|
||||
assert:register("assertion", "same", same, "assertion.same.positive", "assertion.same.negative")
|
||||
assert:register("assertion", "matches", matches, "assertion.matches.positive", "assertion.matches.negative")
|
||||
assert:register("assertion", "match", matches, "assertion.matches.positive", "assertion.matches.negative")
|
||||
assert:register("assertion", "near", near, "assertion.near.positive", "assertion.near.negative")
|
||||
assert:register("assertion", "equals", equals, "assertion.equals.positive", "assertion.equals.negative")
|
||||
assert:register("assertion", "equal", equals, "assertion.equals.positive", "assertion.equals.negative")
|
||||
assert:register("assertion", "unique", unique, "assertion.unique.positive", "assertion.unique.negative")
|
||||
assert:register("assertion", "error", has_error, "assertion.error.positive", "assertion.error.negative")
|
||||
assert:register("assertion", "errors", has_error, "assertion.error.positive", "assertion.error.negative")
|
||||
assert:register("assertion", "error_matches", error_matches, "assertion.error.positive", "assertion.error.negative")
|
||||
assert:register("assertion", "error_match", error_matches, "assertion.error.positive", "assertion.error.negative")
|
||||
assert:register("assertion", "matches_error", error_matches, "assertion.error.positive", "assertion.error.negative")
|
||||
assert:register("assertion", "match_error", error_matches, "assertion.error.positive", "assertion.error.negative")
|
||||
assert:register("assertion", "truthy", truthy, "assertion.truthy.positive", "assertion.truthy.negative")
|
||||
assert:register("assertion", "falsy", falsy, "assertion.falsy.positive", "assertion.falsy.negative")
|
||||
@ -0,0 +1,9 @@
|
||||
-- no longer needed, only for backward compatibility
|
||||
local unpack = require ("luassert.util").unpack
|
||||
|
||||
return {
|
||||
unpack = function(...)
|
||||
print(debug.traceback("WARN: calling deprecated function 'luassert.compatibility.unpack' use 'luassert.util.unpack' instead"))
|
||||
return unpack(...)
|
||||
end
|
||||
}
|
||||
@ -0,0 +1,33 @@
|
||||
local format = function(str)
|
||||
if type(str) ~= "string" then
|
||||
return nil
|
||||
end
|
||||
local result = "Binary string length; " .. tostring(#str) .. " bytes\n"
|
||||
local i = 1
|
||||
local hex = ""
|
||||
local chr = ""
|
||||
while i <= #str do
|
||||
local byte = str:byte(i)
|
||||
hex = string.format("%s%2x ", hex, byte)
|
||||
if byte < 32 then
|
||||
byte = string.byte(".")
|
||||
end
|
||||
chr = chr .. string.char(byte)
|
||||
if math.floor(i / 16) == i / 16 or i == #str then
|
||||
-- reached end of line
|
||||
hex = hex .. string.rep(" ", 16 * 3 - #hex)
|
||||
chr = chr .. string.rep(" ", 16 - #chr)
|
||||
|
||||
result = result .. hex:sub(1, 8 * 3) .. " " .. hex:sub(8 * 3 + 1, -1) .. " " .. chr:sub(1, 8) .. " " ..
|
||||
chr:sub(9, -1) .. "\n"
|
||||
|
||||
hex = ""
|
||||
chr = ""
|
||||
end
|
||||
i = i + 1
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
return format
|
||||
|
||||
@ -0,0 +1,258 @@
|
||||
-- module will not return anything, only register formatters with the main assert engine
|
||||
local assert = require('luassert.assert')
|
||||
local match = require('luassert.match')
|
||||
local util = require('luassert.util')
|
||||
|
||||
local colors = setmetatable({
|
||||
none = function(c)
|
||||
return c
|
||||
end,
|
||||
}, {
|
||||
__index = function(self, key)
|
||||
local ok, term = pcall(require, 'term')
|
||||
local isatty = io.type(io.stdout) == 'file' and ok and term.isatty(io.stdout)
|
||||
if not ok or not isatty or not term.colors then
|
||||
return function(c)
|
||||
return c
|
||||
end
|
||||
end
|
||||
return function(c)
|
||||
for token in key:gmatch("[^%.]+") do
|
||||
c = term.colors[token](c)
|
||||
end
|
||||
return c
|
||||
end
|
||||
end,
|
||||
})
|
||||
|
||||
local function fmt_string(arg)
|
||||
if type(arg) == "string" then
|
||||
return string.format("(string) '%s'", arg)
|
||||
end
|
||||
end
|
||||
|
||||
-- A version of tostring which formats numbers more precisely.
|
||||
local function tostr(arg)
|
||||
if type(arg) ~= "number" then
|
||||
return tostring(arg)
|
||||
end
|
||||
|
||||
if arg ~= arg then
|
||||
return "NaN"
|
||||
elseif arg == 1 / 0 then
|
||||
return "Inf"
|
||||
elseif arg == -1 / 0 then
|
||||
return "-Inf"
|
||||
end
|
||||
|
||||
local str = string.format("%.20g", arg)
|
||||
|
||||
if math.type and math.type(arg) == "float" and not str:find("[%.,]") then
|
||||
-- Number is a float but looks like an integer.
|
||||
-- Insert ".0" after first run of digits.
|
||||
str = str:gsub("%d+", "%0.0", 1)
|
||||
end
|
||||
|
||||
return str
|
||||
end
|
||||
|
||||
local function fmt_number(arg)
|
||||
if type(arg) == "number" then
|
||||
return string.format("(number) %s", tostr(arg))
|
||||
end
|
||||
end
|
||||
|
||||
local function fmt_boolean(arg)
|
||||
if type(arg) == "boolean" then
|
||||
return string.format("(boolean) %s", tostring(arg))
|
||||
end
|
||||
end
|
||||
|
||||
local function fmt_nil(arg)
|
||||
if type(arg) == "nil" then
|
||||
return "(nil)"
|
||||
end
|
||||
end
|
||||
|
||||
local type_priorities = {
|
||||
number = 1,
|
||||
boolean = 2,
|
||||
string = 3,
|
||||
table = 4,
|
||||
["function"] = 5,
|
||||
userdata = 6,
|
||||
thread = 7,
|
||||
}
|
||||
|
||||
local function is_in_array_part(key, length)
|
||||
return type(key) == "number" and 1 <= key and key <= length and math.floor(key) == key
|
||||
end
|
||||
|
||||
local function get_sorted_keys(t)
|
||||
local keys = {}
|
||||
local nkeys = 0
|
||||
|
||||
for key in pairs(t) do
|
||||
nkeys = nkeys + 1
|
||||
keys[nkeys] = key
|
||||
end
|
||||
|
||||
local length = #t
|
||||
|
||||
local function key_comparator(key1, key2)
|
||||
local type1, type2 = type(key1), type(key2)
|
||||
local priority1 = is_in_array_part(key1, length) and 0 or type_priorities[type1] or 8
|
||||
local priority2 = is_in_array_part(key2, length) and 0 or type_priorities[type2] or 8
|
||||
|
||||
if priority1 == priority2 then
|
||||
if type1 == "string" or type1 == "number" then
|
||||
return key1 < key2
|
||||
elseif type1 == "boolean" then
|
||||
return key1 -- put true before false
|
||||
end
|
||||
else
|
||||
return priority1 < priority2
|
||||
end
|
||||
end
|
||||
|
||||
table.sort(keys, key_comparator)
|
||||
return keys, nkeys
|
||||
end
|
||||
|
||||
local function fmt_table(arg, fmtargs)
|
||||
if type(arg) ~= "table" then
|
||||
return
|
||||
end
|
||||
|
||||
local tmax = assert:get_parameter("TableFormatLevel")
|
||||
local showrec = assert:get_parameter("TableFormatShowRecursion")
|
||||
local errchar = assert:get_parameter("TableErrorHighlightCharacter") or ""
|
||||
local errcolor = assert:get_parameter("TableErrorHighlightColor") or "none"
|
||||
local crumbs = fmtargs and fmtargs.crumbs or {}
|
||||
local cache = {}
|
||||
local type_desc
|
||||
|
||||
if getmetatable(arg) == nil then
|
||||
type_desc = "(" .. tostring(arg) .. ") "
|
||||
elseif not pcall(setmetatable, arg, getmetatable(arg)) then
|
||||
-- cannot set same metatable, so it is protected, skip id
|
||||
type_desc = "(table) "
|
||||
else
|
||||
-- unprotected metatable, temporary remove the mt
|
||||
local mt = getmetatable(arg)
|
||||
setmetatable(arg, nil)
|
||||
type_desc = "(" .. tostring(arg) .. ") "
|
||||
setmetatable(arg, mt)
|
||||
end
|
||||
|
||||
local function ft(t, l, with_crumbs)
|
||||
if showrec and cache[t] and cache[t] > 0 then
|
||||
return "{ ... recursive }"
|
||||
end
|
||||
|
||||
if next(t) == nil then
|
||||
return "{ }"
|
||||
end
|
||||
|
||||
if l > tmax and tmax >= 0 then
|
||||
return "{ ... more }"
|
||||
end
|
||||
|
||||
local result = "{"
|
||||
local keys, nkeys = get_sorted_keys(t)
|
||||
|
||||
cache[t] = (cache[t] or 0) + 1
|
||||
local crumb = crumbs[#crumbs - l + 1]
|
||||
|
||||
for i = 1, nkeys do
|
||||
local k = keys[i]
|
||||
local v = t[k]
|
||||
local use_crumbs = with_crumbs and k == crumb
|
||||
|
||||
if type(v) == "table" then
|
||||
v = ft(v, l + 1, use_crumbs)
|
||||
elseif type(v) == "string" then
|
||||
v = "'" .. v .. "'"
|
||||
end
|
||||
|
||||
local ch = use_crumbs and errchar or ""
|
||||
local indent = string.rep(" ", l * 2 - ch:len())
|
||||
local mark = (ch:len() == 0 and "" or colors[errcolor](ch))
|
||||
result = result .. string.format("\n%s%s[%s] = %s", indent, mark, tostr(k), tostr(v))
|
||||
end
|
||||
|
||||
cache[t] = cache[t] - 1
|
||||
|
||||
return result .. " }"
|
||||
end
|
||||
|
||||
return type_desc .. ft(arg, 1, true)
|
||||
end
|
||||
|
||||
local function fmt_function(arg)
|
||||
if type(arg) == "function" then
|
||||
local debug_info = debug.getinfo(arg)
|
||||
return string.format("%s @ line %s in %s", tostring(arg), tostring(debug_info.linedefined),
|
||||
tostring(debug_info.source))
|
||||
end
|
||||
end
|
||||
|
||||
local function fmt_userdata(arg)
|
||||
if type(arg) == "userdata" then
|
||||
return string.format("(userdata) '%s'", tostring(arg))
|
||||
end
|
||||
end
|
||||
|
||||
local function fmt_thread(arg)
|
||||
if type(arg) == "thread" then
|
||||
return string.format("(thread) '%s'", tostring(arg))
|
||||
end
|
||||
end
|
||||
|
||||
local function fmt_matcher(arg)
|
||||
if not match.is_matcher(arg) then
|
||||
return
|
||||
end
|
||||
local not_inverted = {
|
||||
[true] = "is.",
|
||||
[false] = "no.",
|
||||
}
|
||||
local args = {}
|
||||
for idx = 1, arg.arguments.n do
|
||||
table.insert(args, assert:format({
|
||||
arg.arguments[idx],
|
||||
n = 1,
|
||||
})[1])
|
||||
end
|
||||
return string.format("(matcher) %s%s(%s)", not_inverted[arg.mod], tostring(arg.name), table.concat(args, ", "))
|
||||
end
|
||||
|
||||
local function fmt_arglist(arglist)
|
||||
if not util.is_arglist(arglist) then
|
||||
return
|
||||
end
|
||||
local formatted_vals = {}
|
||||
for idx = 1, arglist.n do
|
||||
table.insert(formatted_vals, assert:format({
|
||||
arglist[idx],
|
||||
n = 1,
|
||||
})[1])
|
||||
end
|
||||
return "(values list) (" .. table.concat(formatted_vals, ", ") .. ")"
|
||||
end
|
||||
|
||||
assert:add_formatter(fmt_string)
|
||||
assert:add_formatter(fmt_number)
|
||||
assert:add_formatter(fmt_boolean)
|
||||
assert:add_formatter(fmt_nil)
|
||||
assert:add_formatter(fmt_table)
|
||||
assert:add_formatter(fmt_function)
|
||||
assert:add_formatter(fmt_userdata)
|
||||
assert:add_formatter(fmt_thread)
|
||||
assert:add_formatter(fmt_matcher)
|
||||
assert:add_formatter(fmt_arglist)
|
||||
-- Set default table display depth for table formatter
|
||||
assert:set_parameter("TableFormatLevel", 3)
|
||||
assert:set_parameter("TableFormatShowRecursion", false)
|
||||
assert:set_parameter("TableErrorHighlightCharacter", "*")
|
||||
assert:set_parameter("TableErrorHighlightColor", "none")
|
||||
@ -0,0 +1,18 @@
|
||||
local assert = require('luassert.assert')
|
||||
|
||||
assert._COPYRIGHT = "Copyright (c) 2018 Olivine Labs, LLC."
|
||||
assert._DESCRIPTION =
|
||||
"Extends Lua's built-in assertions to provide additional tests and the ability to create your own."
|
||||
assert._VERSION = "Luassert 1.8.0"
|
||||
|
||||
-- load basic asserts
|
||||
require('luassert.assertions')
|
||||
require('luassert.modifiers')
|
||||
require('luassert.array')
|
||||
require('luassert.matchers')
|
||||
require('luassert.formatters')
|
||||
|
||||
-- load default language
|
||||
require('luassert.languages.en')
|
||||
|
||||
return assert
|
||||
@ -0,0 +1,52 @@
|
||||
local s = require('luassert.say')
|
||||
|
||||
s:set_namespace('en')
|
||||
|
||||
s:set("assertion.same.positive", "Expected objects to be the same.\nPassed in:\n%s\nExpected:\n%s")
|
||||
s:set("assertion.same.negative", "Expected objects to not be the same.\nPassed in:\n%s\nDid not expect:\n%s")
|
||||
|
||||
s:set("assertion.equals.positive", "Expected objects to be equal.\nPassed in:\n%s\nExpected:\n%s")
|
||||
s:set("assertion.equals.negative", "Expected objects to not be equal.\nPassed in:\n%s\nDid not expect:\n%s")
|
||||
|
||||
s:set("assertion.near.positive", "Expected values to be near.\nPassed in:\n%s\nExpected:\n%s +/- %s")
|
||||
s:set("assertion.near.negative", "Expected values to not be near.\nPassed in:\n%s\nDid not expect:\n%s +/- %s")
|
||||
|
||||
s:set("assertion.matches.positive", "Expected strings to match.\nPassed in:\n%s\nExpected:\n%s")
|
||||
s:set("assertion.matches.negative", "Expected strings not to match.\nPassed in:\n%s\nDid not expect:\n%s")
|
||||
|
||||
s:set("assertion.unique.positive", "Expected object to be unique:\n%s")
|
||||
s:set("assertion.unique.negative", "Expected object to not be unique:\n%s")
|
||||
|
||||
s:set("assertion.error.positive", "Expected a different error.\nCaught:\n%s\nExpected:\n%s")
|
||||
s:set("assertion.error.negative", "Expected no error, but caught:\n%s")
|
||||
|
||||
s:set("assertion.truthy.positive", "Expected to be truthy, but value was:\n%s")
|
||||
s:set("assertion.truthy.negative", "Expected to not be truthy, but value was:\n%s")
|
||||
|
||||
s:set("assertion.falsy.positive", "Expected to be falsy, but value was:\n%s")
|
||||
s:set("assertion.falsy.negative", "Expected to not be falsy, but value was:\n%s")
|
||||
|
||||
s:set("assertion.called.positive", "Expected to be called %s time(s), but was called %s time(s)")
|
||||
s:set("assertion.called.negative", "Expected not to be called exactly %s time(s), but it was.")
|
||||
|
||||
s:set("assertion.called_at_least.positive", "Expected to be called at least %s time(s), but was called %s time(s)")
|
||||
s:set("assertion.called_at_most.positive", "Expected to be called at most %s time(s), but was called %s time(s)")
|
||||
s:set("assertion.called_more_than.positive", "Expected to be called more than %s time(s), but was called %s time(s)")
|
||||
s:set("assertion.called_less_than.positive", "Expected to be called less than %s time(s), but was called %s time(s)")
|
||||
|
||||
s:set("assertion.called_with.positive",
|
||||
"Function was never called with matching arguments.\nCalled with (last call if any):\n%s\nExpected:\n%s")
|
||||
s:set("assertion.called_with.negative",
|
||||
"Function was called with matching arguments at least once.\nCalled with (last matching call):\n%s\nDid not expect:\n%s")
|
||||
|
||||
s:set("assertion.returned_with.positive",
|
||||
"Function never returned matching arguments.\nReturned (last call if any):\n%s\nExpected:\n%s")
|
||||
s:set("assertion.returned_with.negative",
|
||||
"Function returned matching arguments at least once.\nReturned (last matching call):\n%s\nDid not expect:\n%s")
|
||||
|
||||
s:set("assertion.returned_arguments.positive", "Expected to be called with %s argument(s), but was called with %s")
|
||||
s:set("assertion.returned_arguments.negative", "Expected not to be called with %s argument(s), but was called with %s")
|
||||
|
||||
-- errors
|
||||
s:set("assertion.internal.argtolittle", "the '%s' function requires a minimum of %s arguments, got: %s")
|
||||
s:set("assertion.internal.badargtype", "bad argument #%s to '%s' (%s expected, got %s)")
|
||||
@ -0,0 +1,31 @@
|
||||
local s = require('luassert.say')
|
||||
|
||||
s:set_namespace('zh')
|
||||
|
||||
s:set("assertion.same.positive", "希望对象应该相同.\n实际值:\n%s\n希望值:\n%s")
|
||||
s:set("assertion.same.negative", "希望对象应该不相同.\n实际值:\n%s\n不希望与:\n%s\n相同")
|
||||
|
||||
s:set("assertion.equals.positive", "希望对象应该相等.\n实际值:\n%s\n希望值:\n%s")
|
||||
s:set("assertion.equals.negative", "希望对象应该不相等.\n实际值:\n%s\n不希望等于:\n%s")
|
||||
|
||||
s:set("assertion.unique.positive", "希望对象是唯一的:\n%s")
|
||||
s:set("assertion.unique.negative", "希望对象不是唯一的:\n%s")
|
||||
|
||||
s:set("assertion.error.positive", "希望有错误被抛出.")
|
||||
s:set("assertion.error.negative", "希望没有错误被抛出.\n%s")
|
||||
|
||||
s:set("assertion.truthy.positive", "希望结果为真,但是实际为:\n%s")
|
||||
s:set("assertion.truthy.negative", "希望结果不为真,但是实际为:\n%s")
|
||||
|
||||
s:set("assertion.falsy.positive", "希望结果为假,但是实际为:\n%s")
|
||||
s:set("assertion.falsy.negative", "希望结果不为假,但是实际为:\n%s")
|
||||
|
||||
s:set("assertion.called.positive", "希望被调用%s次, 但实际被调用了%s次")
|
||||
s:set("assertion.called.negative", "不希望正好被调用%s次, 但是正好被调用了那么多次.")
|
||||
|
||||
s:set("assertion.called_with.positive", "希望没有参数的调用函数")
|
||||
s:set("assertion.called_with.negative", "希望有参数的调用函数")
|
||||
|
||||
-- errors
|
||||
s:set("assertion.internal.argtolittle", "函数'%s'需要最少%s个参数, 实际有%s个参数\n")
|
||||
s:set("assertion.internal.badargtype", "bad argument #%s: 函数'%s'需要一个%s作为参数, 实际为: %s\n")
|
||||
@ -0,0 +1,79 @@
|
||||
local namespace = require 'luassert.namespaces'
|
||||
local util = require 'luassert.util'
|
||||
|
||||
local matcher_mt = {
|
||||
__call = function(self, value)
|
||||
return self.callback(value) == self.mod
|
||||
end,
|
||||
}
|
||||
|
||||
local state_mt = {
|
||||
__call = function(self, ...)
|
||||
local keys = util.extract_keys("matcher", self.tokens)
|
||||
self.tokens = {}
|
||||
|
||||
local matcher
|
||||
|
||||
for _, key in ipairs(keys) do
|
||||
matcher = namespace.matcher[key] or matcher
|
||||
end
|
||||
|
||||
if matcher then
|
||||
for _, key in ipairs(keys) do
|
||||
if namespace.modifier[key] then
|
||||
namespace.modifier[key].callback(self)
|
||||
end
|
||||
end
|
||||
|
||||
local arguments = util.make_arglist(...)
|
||||
local matches = matcher.callback(self, arguments, util.errorlevel())
|
||||
return setmetatable({
|
||||
name = matcher.name,
|
||||
mod = self.mod,
|
||||
callback = matches,
|
||||
arguments = arguments,
|
||||
}, matcher_mt)
|
||||
else
|
||||
local arguments = util.make_arglist(...)
|
||||
|
||||
for _, key in ipairs(keys) do
|
||||
if namespace.modifier[key] then
|
||||
namespace.modifier[key].callback(self, arguments, util.errorlevel())
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return self
|
||||
end,
|
||||
|
||||
__index = function(self, key)
|
||||
for token in key:lower():gmatch('[^_]+') do
|
||||
table.insert(self.tokens, token)
|
||||
end
|
||||
|
||||
return self
|
||||
end
|
||||
}
|
||||
|
||||
local match = {
|
||||
_ = setmetatable({mod=true, callback=function() return true end}, matcher_mt),
|
||||
|
||||
state = function() return setmetatable({mod=true, tokens={}}, state_mt) end,
|
||||
|
||||
is_matcher = function(object)
|
||||
return type(object) == "table" and getmetatable(object) == matcher_mt
|
||||
end,
|
||||
|
||||
is_ref_matcher = function(object)
|
||||
local ismatcher = (type(object) == "table" and getmetatable(object) == matcher_mt)
|
||||
return ismatcher and object.name == "ref"
|
||||
end,
|
||||
}
|
||||
|
||||
local mt = {
|
||||
__index = function(self, key)
|
||||
return rawget(self, key) or self.state()[key]
|
||||
end,
|
||||
}
|
||||
|
||||
return setmetatable(match, mt)
|
||||
@ -0,0 +1,64 @@
|
||||
local assert = require('luassert.assert')
|
||||
local match = require('luassert.match')
|
||||
local s = require('luassert.say')
|
||||
|
||||
local function none(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 0, s("assertion.internal.argtolittle", {"none", 1, tostring(argcnt)}), level)
|
||||
for i = 1, argcnt do
|
||||
assert(match.is_matcher(arguments[i]),
|
||||
s("assertion.internal.badargtype", {1, "none", "matcher", type(arguments[i])}), level)
|
||||
end
|
||||
|
||||
return function(value)
|
||||
for _, matcher in ipairs(arguments) do
|
||||
if matcher(value) then
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
local function any(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 0, s("assertion.internal.argtolittle", {"any", 1, tostring(argcnt)}), level)
|
||||
for i = 1, argcnt do
|
||||
assert(match.is_matcher(arguments[i]),
|
||||
s("assertion.internal.badargtype", {1, "any", "matcher", type(arguments[i])}), level)
|
||||
end
|
||||
|
||||
return function(value)
|
||||
for _, matcher in ipairs(arguments) do
|
||||
if matcher(value) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
local function all(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 0, s("assertion.internal.argtolittle", {"all", 1, tostring(argcnt)}), level)
|
||||
for i = 1, argcnt do
|
||||
assert(match.is_matcher(arguments[i]),
|
||||
s("assertion.internal.badargtype", {1, "all", "matcher", type(arguments[i])}), level)
|
||||
end
|
||||
|
||||
return function(value)
|
||||
for _, matcher in ipairs(arguments) do
|
||||
if not matcher(value) then
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
assert:register("matcher", "none_of", none)
|
||||
assert:register("matcher", "any_of", any)
|
||||
assert:register("matcher", "all_of", all)
|
||||
@ -0,0 +1,173 @@
|
||||
-- module will return the list of matchers, and registers matchers with the main assert engine
|
||||
|
||||
-- matchers take 1 parameters;
|
||||
-- 1) state
|
||||
-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils
|
||||
-- 3) level The level of the error position relative to the called function
|
||||
-- returns; function (or callable object); a function that, given an argument, returns a boolean
|
||||
|
||||
local assert = require('luassert.assert')
|
||||
local astate = require('luassert.state')
|
||||
local util = require('luassert.util')
|
||||
local s = require('luassert.say')
|
||||
|
||||
local function format(val)
|
||||
return astate.format_argument(val) or tostring(val)
|
||||
end
|
||||
|
||||
local function unique(state, arguments, level)
|
||||
local deep = arguments[1]
|
||||
return function(value)
|
||||
local list = value
|
||||
for k,v in pairs(list) do
|
||||
for k2, v2 in pairs(list) do
|
||||
if k ~= k2 then
|
||||
if deep and util.deepcompare(v, v2, true) then
|
||||
return false
|
||||
else
|
||||
if v == v2 then
|
||||
return false
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
local function near(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 1, s("assertion.internal.argtolittle", { "near", 2, tostring(argcnt) }), level)
|
||||
local expected = tonumber(arguments[1])
|
||||
local tolerance = tonumber(arguments[2])
|
||||
local numbertype = "number or object convertible to a number"
|
||||
assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level)
|
||||
assert(tolerance, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level)
|
||||
|
||||
return function(value)
|
||||
local actual = tonumber(value)
|
||||
if not actual then return false end
|
||||
return (actual >= expected - tolerance and actual <= expected + tolerance)
|
||||
end
|
||||
end
|
||||
|
||||
local function matches(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 0, s("assertion.internal.argtolittle", { "matches", 1, tostring(argcnt) }), level)
|
||||
local pattern = arguments[1]
|
||||
local init = arguments[2]
|
||||
local plain = arguments[3]
|
||||
assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
|
||||
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { 2, "matches", "number", type(arguments[2]) }), level)
|
||||
|
||||
return function(value)
|
||||
local actualtype = type(value)
|
||||
local actual = nil
|
||||
if actualtype == "string" or actualtype == "number" or
|
||||
actualtype == "table" and (getmetatable(value) or {}).__tostring then
|
||||
actual = tostring(value)
|
||||
end
|
||||
if not actual then return false end
|
||||
return (actual:find(pattern, init, plain) ~= nil)
|
||||
end
|
||||
end
|
||||
|
||||
local function equals(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 0, s("assertion.internal.argtolittle", { "equals", 1, tostring(argcnt) }), level)
|
||||
return function(value)
|
||||
return value == arguments[1]
|
||||
end
|
||||
end
|
||||
|
||||
local function same(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
assert(argcnt > 0, s("assertion.internal.argtolittle", { "same", 1, tostring(argcnt) }), level)
|
||||
return function(value)
|
||||
if type(value) == 'table' and type(arguments[1]) == 'table' then
|
||||
local result = util.deepcompare(value, arguments[1], true)
|
||||
return result
|
||||
end
|
||||
return value == arguments[1]
|
||||
end
|
||||
end
|
||||
|
||||
local function ref(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local argcnt = arguments.n
|
||||
local argtype = type(arguments[1])
|
||||
local isobject = (argtype == "table" or argtype == "function" or argtype == "thread" or argtype == "userdata")
|
||||
assert(argcnt > 0, s("assertion.internal.argtolittle", { "ref", 1, tostring(argcnt) }), level)
|
||||
assert(isobject, s("assertion.internal.badargtype", { 1, "ref", "object", argtype }), level)
|
||||
return function(value)
|
||||
return value == arguments[1]
|
||||
end
|
||||
end
|
||||
|
||||
local function is_true(state, arguments, level)
|
||||
return function(value)
|
||||
return value == true
|
||||
end
|
||||
end
|
||||
|
||||
local function is_false(state, arguments, level)
|
||||
return function(value)
|
||||
return value == false
|
||||
end
|
||||
end
|
||||
|
||||
local function truthy(state, arguments, level)
|
||||
return function(value)
|
||||
return value ~= false and value ~= nil
|
||||
end
|
||||
end
|
||||
|
||||
local function falsy(state, arguments, level)
|
||||
local is_truthy = truthy(state, arguments, level)
|
||||
return function(value)
|
||||
return not is_truthy(value)
|
||||
end
|
||||
end
|
||||
|
||||
local function is_type(state, arguments, level, etype)
|
||||
return function(value)
|
||||
return type(value) == etype
|
||||
end
|
||||
end
|
||||
|
||||
local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end
|
||||
local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end
|
||||
local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end
|
||||
local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end
|
||||
local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end
|
||||
local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end
|
||||
local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end
|
||||
local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end
|
||||
|
||||
assert:register("matcher", "true", is_true)
|
||||
assert:register("matcher", "false", is_false)
|
||||
|
||||
assert:register("matcher", "nil", is_nil)
|
||||
assert:register("matcher", "boolean", is_boolean)
|
||||
assert:register("matcher", "number", is_number)
|
||||
assert:register("matcher", "string", is_string)
|
||||
assert:register("matcher", "table", is_table)
|
||||
assert:register("matcher", "function", is_function)
|
||||
assert:register("matcher", "userdata", is_userdata)
|
||||
assert:register("matcher", "thread", is_thread)
|
||||
|
||||
assert:register("matcher", "ref", ref)
|
||||
assert:register("matcher", "same", same)
|
||||
assert:register("matcher", "matches", matches)
|
||||
assert:register("matcher", "match", matches)
|
||||
assert:register("matcher", "near", near)
|
||||
assert:register("matcher", "equals", equals)
|
||||
assert:register("matcher", "equal", equals)
|
||||
assert:register("matcher", "unique", unique)
|
||||
assert:register("matcher", "truthy", truthy)
|
||||
assert:register("matcher", "falsy", falsy)
|
||||
@ -0,0 +1,3 @@
|
||||
-- load basic machers
|
||||
require('luassert.matchers.core')
|
||||
require('luassert.matchers.composite')
|
||||
@ -0,0 +1,65 @@
|
||||
-- module will return a mock module table, and will not register any assertions
|
||||
local spy = require 'luassert.spy'
|
||||
local stub = require 'luassert.stub'
|
||||
|
||||
local function mock_apply(object, action)
|
||||
if type(object) ~= "table" then
|
||||
return
|
||||
end
|
||||
if spy.is_spy(object) then
|
||||
return object[action](object)
|
||||
end
|
||||
for k, v in pairs(object) do
|
||||
mock_apply(v, action)
|
||||
end
|
||||
return object
|
||||
end
|
||||
|
||||
local mock
|
||||
mock = {
|
||||
new = function(object, dostub, func, self, key)
|
||||
local visited = {}
|
||||
local function do_mock(object, self, key)
|
||||
local mock_handlers = {
|
||||
["table"] = function()
|
||||
if spy.is_spy(object) or visited[object] then
|
||||
return
|
||||
end
|
||||
visited[object] = true
|
||||
for k, v in pairs(object) do
|
||||
object[k] = do_mock(v, object, k)
|
||||
end
|
||||
return object
|
||||
end,
|
||||
["function"] = function()
|
||||
if dostub then
|
||||
return stub(self, key, func)
|
||||
elseif self == nil then
|
||||
return spy.new(object)
|
||||
else
|
||||
return spy.on(self, key)
|
||||
end
|
||||
end,
|
||||
}
|
||||
local handler = mock_handlers[type(object)]
|
||||
return handler and handler() or object
|
||||
end
|
||||
return do_mock(object, self, key)
|
||||
end,
|
||||
|
||||
clear = function(object)
|
||||
return mock_apply(object, "clear")
|
||||
end,
|
||||
|
||||
revert = function(object)
|
||||
return mock_apply(object, "revert")
|
||||
end,
|
||||
}
|
||||
|
||||
return setmetatable(mock, {
|
||||
__call = function(self, ...)
|
||||
-- mock originally was a function only. Now that it is a module table
|
||||
-- the __call method is required for backward compatibility
|
||||
return mock.new(...)
|
||||
end,
|
||||
})
|
||||
@ -0,0 +1,19 @@
|
||||
-- module will not return anything, only register assertions/modifiers with the main assert engine
|
||||
local assert = require('luassert.assert')
|
||||
|
||||
local function is(state)
|
||||
return state
|
||||
end
|
||||
|
||||
local function is_not(state)
|
||||
state.mod = not state.mod
|
||||
return state
|
||||
end
|
||||
|
||||
assert:register("modifier", "is", is)
|
||||
assert:register("modifier", "are", is)
|
||||
assert:register("modifier", "was", is)
|
||||
assert:register("modifier", "has", is)
|
||||
assert:register("modifier", "does", is)
|
||||
assert:register("modifier", "not", is_not)
|
||||
assert:register("modifier", "no", is_not)
|
||||
@ -0,0 +1,2 @@
|
||||
-- stores the list of namespaces
|
||||
return {}
|
||||
@ -0,0 +1,64 @@
|
||||
local unpack = table.unpack or unpack
|
||||
|
||||
local registry = {}
|
||||
local current_namespace
|
||||
local fallback_namespace
|
||||
|
||||
local s = {
|
||||
|
||||
_COPYRIGHT = "Copyright (c) 2012 Olivine Labs, LLC.",
|
||||
_DESCRIPTION = "A simple string key/value store for i18n or any other case where you want namespaced strings.",
|
||||
_VERSION = "Say 1.3",
|
||||
|
||||
set_namespace = function(self, namespace)
|
||||
current_namespace = namespace
|
||||
if not registry[current_namespace] then
|
||||
registry[current_namespace] = {}
|
||||
end
|
||||
end,
|
||||
|
||||
set_fallback = function(self, namespace)
|
||||
fallback_namespace = namespace
|
||||
if not registry[fallback_namespace] then
|
||||
registry[fallback_namespace] = {}
|
||||
end
|
||||
end,
|
||||
|
||||
set = function(self, key, value)
|
||||
registry[current_namespace][key] = value
|
||||
end,
|
||||
}
|
||||
|
||||
local __meta = {
|
||||
__call = function(self, key, vars)
|
||||
if vars ~= nil and type(vars) ~= "table" then
|
||||
error(("expected parameter table to be a table, got '%s'"):format(type(vars)), 2)
|
||||
end
|
||||
vars = vars or {}
|
||||
|
||||
local str = registry[current_namespace][key] or registry[fallback_namespace][key]
|
||||
|
||||
if str == nil then
|
||||
return nil
|
||||
end
|
||||
str = tostring(str)
|
||||
local strings = {}
|
||||
|
||||
for i = 1, vars.n or #vars do
|
||||
table.insert(strings, tostring(vars[i]))
|
||||
end
|
||||
|
||||
return #strings > 0 and str:format(unpack(strings)) or str
|
||||
end,
|
||||
|
||||
__index = function(self, key)
|
||||
return registry[key]
|
||||
end,
|
||||
}
|
||||
|
||||
s:set_fallback('en')
|
||||
s:set_namespace('en')
|
||||
|
||||
s._registry = registry
|
||||
|
||||
return setmetatable(s, __meta)
|
||||
@ -0,0 +1,215 @@
|
||||
-- module will return spy table, and register its assertions with the main assert engine
|
||||
local assert = require('luassert.assert')
|
||||
local util = require('luassert.util')
|
||||
|
||||
-- Spy metatable
|
||||
local spy_mt = {
|
||||
__call = function(self, ...)
|
||||
local arguments = util.make_arglist(...)
|
||||
table.insert(self.calls, util.copyargs(arguments))
|
||||
local function get_returns(...)
|
||||
local returnvals = util.make_arglist(...)
|
||||
table.insert(self.returnvals, util.copyargs(returnvals))
|
||||
return ...
|
||||
end
|
||||
return get_returns(self.callback(...))
|
||||
end,
|
||||
}
|
||||
|
||||
local spy -- must make local before defining table, because table contents refers to the table (recursion)
|
||||
spy = {
|
||||
new = function(callback)
|
||||
callback = callback or function()
|
||||
end
|
||||
if not util.callable(callback) then
|
||||
error("Cannot spy on type '" .. type(callback) .. "', only on functions or callable elements",
|
||||
util.errorlevel())
|
||||
end
|
||||
local s = setmetatable({
|
||||
calls = {},
|
||||
returnvals = {},
|
||||
callback = callback,
|
||||
|
||||
target_table = nil, -- these will be set when using 'spy.on'
|
||||
target_key = nil,
|
||||
|
||||
revert = function(self)
|
||||
if not self.reverted then
|
||||
if self.target_table and self.target_key then
|
||||
self.target_table[self.target_key] = self.callback
|
||||
end
|
||||
self.reverted = true
|
||||
end
|
||||
return self.callback
|
||||
end,
|
||||
|
||||
clear = function(self)
|
||||
self.calls = {}
|
||||
self.returnvals = {}
|
||||
return self
|
||||
end,
|
||||
|
||||
called = function(self, times, compare)
|
||||
if times or compare then
|
||||
local compare = compare or function(count, expected)
|
||||
return count == expected
|
||||
end
|
||||
return compare(#self.calls, times), #self.calls
|
||||
end
|
||||
|
||||
return (#self.calls > 0), #self.calls
|
||||
end,
|
||||
|
||||
called_with = function(self, args)
|
||||
local last_arglist = nil
|
||||
if #self.calls > 0 then
|
||||
last_arglist = self.calls[#self.calls].vals
|
||||
end
|
||||
local matching_arglists = util.matchargs(self.calls, args)
|
||||
if matching_arglists ~= nil then
|
||||
return true, matching_arglists.vals
|
||||
end
|
||||
return false, last_arglist
|
||||
end,
|
||||
|
||||
returned_with = function(self, args)
|
||||
local last_returnvallist = nil
|
||||
if #self.returnvals > 0 then
|
||||
last_returnvallist = self.returnvals[#self.returnvals].vals
|
||||
end
|
||||
local matching_returnvallists = util.matchargs(self.returnvals, args)
|
||||
if matching_returnvallists ~= nil then
|
||||
return true, matching_returnvallists.vals
|
||||
end
|
||||
return false, last_returnvallist
|
||||
end,
|
||||
}, spy_mt)
|
||||
assert:add_spy(s) -- register with the current state
|
||||
return s
|
||||
end,
|
||||
|
||||
is_spy = function(object)
|
||||
return type(object) == "table" and getmetatable(object) == spy_mt
|
||||
end,
|
||||
|
||||
on = function(target_table, target_key)
|
||||
local s = spy.new(target_table[target_key])
|
||||
target_table[target_key] = s
|
||||
-- store original data
|
||||
s.target_table = target_table
|
||||
s.target_key = target_key
|
||||
|
||||
return s
|
||||
end,
|
||||
}
|
||||
|
||||
local function set_spy(state, arguments, level)
|
||||
state.payload = arguments[1]
|
||||
if arguments[2] ~= nil then
|
||||
state.failure_message = arguments[2]
|
||||
end
|
||||
end
|
||||
|
||||
local function returned_with(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local payload = rawget(state, "payload")
|
||||
if payload and payload.returned_with then
|
||||
local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments)
|
||||
local expected_returnvallist = util.shallowcopy(arguments)
|
||||
util.cleararglist(arguments)
|
||||
util.tinsert(arguments, 1, matching_or_last_returnvallist)
|
||||
util.tinsert(arguments, 2, expected_returnvallist)
|
||||
return assertion_holds
|
||||
else
|
||||
error("'returned_with' must be chained after 'spy(aspy)'", level)
|
||||
end
|
||||
end
|
||||
|
||||
local function called_with(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
local payload = rawget(state, "payload")
|
||||
if payload and payload.called_with then
|
||||
local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments)
|
||||
local expected_arglist = util.shallowcopy(arguments)
|
||||
util.cleararglist(arguments)
|
||||
util.tinsert(arguments, 1, matching_or_last_arglist)
|
||||
util.tinsert(arguments, 2, expected_arglist)
|
||||
return assertion_holds
|
||||
else
|
||||
error("'called_with' must be chained after 'spy(aspy)'", level)
|
||||
end
|
||||
end
|
||||
|
||||
local function called(state, arguments, level, compare)
|
||||
local level = (level or 1) + 1
|
||||
local num_times = arguments[1]
|
||||
if not num_times and not state.mod then
|
||||
state.mod = true
|
||||
num_times = 0
|
||||
end
|
||||
local payload = rawget(state, "payload")
|
||||
if payload and type(payload) == "table" and payload.called then
|
||||
local result, count = state.payload:called(num_times, compare)
|
||||
arguments[1] = tostring(num_times or ">0")
|
||||
util.tinsert(arguments, 2, tostring(count))
|
||||
arguments.nofmt = arguments.nofmt or {}
|
||||
arguments.nofmt[1] = true
|
||||
arguments.nofmt[2] = true
|
||||
return result
|
||||
elseif payload and type(payload) == "function" then
|
||||
error(
|
||||
"When calling 'spy(aspy)', 'aspy' must not be the original function, but the spy function replacing the original",
|
||||
level)
|
||||
else
|
||||
error("'called' must be chained after 'spy(aspy)'", level)
|
||||
end
|
||||
end
|
||||
|
||||
local function called_at_least(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
return called(state, arguments, level, function(count, expected)
|
||||
return count >= expected
|
||||
end)
|
||||
end
|
||||
|
||||
local function called_at_most(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
return called(state, arguments, level, function(count, expected)
|
||||
return count <= expected
|
||||
end)
|
||||
end
|
||||
|
||||
local function called_more_than(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
return called(state, arguments, level, function(count, expected)
|
||||
return count > expected
|
||||
end)
|
||||
end
|
||||
|
||||
local function called_less_than(state, arguments, level)
|
||||
local level = (level or 1) + 1
|
||||
return called(state, arguments, level, function(count, expected)
|
||||
return count < expected
|
||||
end)
|
||||
end
|
||||
|
||||
assert:register("modifier", "spy", set_spy)
|
||||
assert:register("assertion", "returned_with", returned_with, "assertion.returned_with.positive",
|
||||
"assertion.returned_with.negative")
|
||||
assert:register("assertion", "called_with", called_with, "assertion.called_with.positive",
|
||||
"assertion.called_with.negative")
|
||||
assert:register("assertion", "called", called, "assertion.called.positive", "assertion.called.negative")
|
||||
assert:register("assertion", "called_at_least", called_at_least, "assertion.called_at_least.positive",
|
||||
"assertion.called_less_than.positive")
|
||||
assert:register("assertion", "called_at_most", called_at_most, "assertion.called_at_most.positive",
|
||||
"assertion.called_more_than.positive")
|
||||
assert:register("assertion", "called_more_than", called_more_than, "assertion.called_more_than.positive",
|
||||
"assertion.called_at_most.positive")
|
||||
assert:register("assertion", "called_less_than", called_less_than, "assertion.called_less_than.positive",
|
||||
"assertion.called_at_least.positive")
|
||||
|
||||
return setmetatable(spy, {
|
||||
__call = function(self, ...)
|
||||
return spy.new(...)
|
||||
end,
|
||||
})
|
||||
@ -0,0 +1,134 @@
|
||||
-- maintains a state of the assert engine in a linked-list fashion
|
||||
-- records; formatters, parameters, spies and stubs
|
||||
local state_mt = {
|
||||
__call = function(self)
|
||||
self:revert()
|
||||
end,
|
||||
}
|
||||
|
||||
local spies_mt = {
|
||||
__mode = "kv",
|
||||
}
|
||||
|
||||
local nilvalue = {} -- unique ID to refer to nil values for parameters
|
||||
|
||||
-- will hold the current state
|
||||
local current
|
||||
|
||||
-- exported module table
|
||||
local state = {}
|
||||
|
||||
------------------------------------------------------
|
||||
-- Reverts to a (specific) snapshot.
|
||||
-- @param self (optional) the snapshot to revert to. If not provided, it will revert to the last snapshot.
|
||||
state.revert = function(self)
|
||||
if not self then
|
||||
-- no snapshot given, so move 1 up
|
||||
self = current
|
||||
if not self.previous then
|
||||
-- top of list, no previous one, nothing to do
|
||||
return
|
||||
end
|
||||
end
|
||||
if getmetatable(self) ~= state_mt then
|
||||
error("Value provided is not a valid snapshot", 2)
|
||||
end
|
||||
|
||||
if self.next then
|
||||
self.next:revert()
|
||||
end
|
||||
-- revert formatters in 'last'
|
||||
self.formatters = {}
|
||||
-- revert parameters in 'last'
|
||||
self.parameters = {}
|
||||
-- revert spies/stubs in 'last'
|
||||
for s, _ in pairs(self.spies) do
|
||||
self.spies[s] = nil
|
||||
s:revert()
|
||||
end
|
||||
setmetatable(self, nil) -- invalidate as a snapshot
|
||||
current = self.previous
|
||||
current.next = nil
|
||||
end
|
||||
|
||||
------------------------------------------------------
|
||||
-- Creates a new snapshot.
|
||||
-- @return snapshot table
|
||||
state.snapshot = function()
|
||||
local new = setmetatable({
|
||||
formatters = {},
|
||||
parameters = {},
|
||||
spies = setmetatable({}, spies_mt),
|
||||
previous = current,
|
||||
revert = state.revert,
|
||||
}, state_mt)
|
||||
if current then
|
||||
current.next = new
|
||||
end
|
||||
current = new
|
||||
return current
|
||||
end
|
||||
|
||||
-- FORMATTERS
|
||||
state.add_formatter = function(callback)
|
||||
table.insert(current.formatters, 1, callback)
|
||||
end
|
||||
|
||||
state.remove_formatter = function(callback, s)
|
||||
s = s or current
|
||||
for i, v in ipairs(s.formatters) do
|
||||
if v == callback then
|
||||
table.remove(s.formatters, i)
|
||||
break
|
||||
end
|
||||
end
|
||||
-- wasn't found, so traverse up 1 state
|
||||
if s.previous then
|
||||
state.remove_formatter(callback, s.previous)
|
||||
end
|
||||
end
|
||||
|
||||
state.format_argument = function(val, s, fmtargs)
|
||||
s = s or current
|
||||
for _, fmt in ipairs(s.formatters) do
|
||||
local valfmt = fmt(val, fmtargs)
|
||||
if valfmt ~= nil then
|
||||
return valfmt
|
||||
end
|
||||
end
|
||||
-- nothing found, check snapshot 1 up in list
|
||||
if s.previous then
|
||||
return state.format_argument(val, s.previous, fmtargs)
|
||||
end
|
||||
return nil -- end of list, couldn't format
|
||||
end
|
||||
|
||||
-- PARAMETERS
|
||||
state.set_parameter = function(name, value)
|
||||
if value == nil then
|
||||
value = nilvalue
|
||||
end
|
||||
current.parameters[name] = value
|
||||
end
|
||||
|
||||
state.get_parameter = function(name, s)
|
||||
s = s or current
|
||||
local val = s.parameters[name]
|
||||
if val == nil and s.previous then
|
||||
-- not found, so check 1 up in list
|
||||
return state.get_parameter(name, s.previous)
|
||||
end
|
||||
if val ~= nilvalue then
|
||||
return val
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- SPIES / STUBS
|
||||
state.add_spy = function(spy)
|
||||
current.spies[spy] = true
|
||||
end
|
||||
|
||||
state.snapshot() -- create initial state
|
||||
|
||||
return state
|
||||
@ -0,0 +1,109 @@
|
||||
-- module will return a stub module table
|
||||
local assert = require 'luassert.assert'
|
||||
local spy = require 'luassert.spy'
|
||||
local util = require 'luassert.util'
|
||||
local unpack = util.unpack
|
||||
local pack = util.pack
|
||||
|
||||
local stub = {}
|
||||
|
||||
function stub.new(object, key, ...)
|
||||
if object == nil and key == nil then
|
||||
-- called without arguments, create a 'blank' stub
|
||||
object = {}
|
||||
key = ""
|
||||
end
|
||||
local return_values = pack(...)
|
||||
assert(type(object) == "table" and key ~= nil,
|
||||
"stub.new(): Can only create stub on a table key, call with 2 params; table, key", util.errorlevel())
|
||||
assert(object[key] == nil or util.callable(object[key]),
|
||||
"stub.new(): The element for which to create a stub must either be callable, or be nil", util.errorlevel())
|
||||
local old_elem = object[key] -- keep existing element (might be nil!)
|
||||
|
||||
local fn = (return_values.n == 1 and util.callable(return_values[1]) and return_values[1])
|
||||
local defaultfunc = fn or function()
|
||||
return unpack(return_values)
|
||||
end
|
||||
local oncalls = {}
|
||||
local callbacks = {}
|
||||
local stubfunc = function(...)
|
||||
local args = util.make_arglist(...)
|
||||
local match = util.matchoncalls(oncalls, args)
|
||||
if match then
|
||||
return callbacks[match](...)
|
||||
end
|
||||
return defaultfunc(...)
|
||||
end
|
||||
|
||||
object[key] = stubfunc -- set the stubfunction
|
||||
local s = spy.on(object, key) -- create a spy on top of the stub function
|
||||
local spy_revert = s.revert -- keep created revert function
|
||||
|
||||
s.revert = function(self) -- wrap revert function to restore original element
|
||||
if not self.reverted then
|
||||
spy_revert(self)
|
||||
object[key] = old_elem
|
||||
self.reverted = true
|
||||
end
|
||||
return old_elem
|
||||
end
|
||||
|
||||
s.returns = function(...)
|
||||
local return_args = pack(...)
|
||||
defaultfunc = function()
|
||||
return unpack(return_args)
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
||||
s.invokes = function(func)
|
||||
defaultfunc = function(...)
|
||||
return func(...)
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
||||
s.by_default = {
|
||||
returns = s.returns,
|
||||
invokes = s.invokes,
|
||||
}
|
||||
|
||||
s.on_call_with = function(...)
|
||||
local match_args = util.make_arglist(...)
|
||||
match_args = util.copyargs(match_args)
|
||||
return {
|
||||
returns = function(...)
|
||||
local return_args = pack(...)
|
||||
table.insert(oncalls, match_args)
|
||||
callbacks[match_args] = function()
|
||||
return unpack(return_args)
|
||||
end
|
||||
return s
|
||||
end,
|
||||
invokes = function(func)
|
||||
table.insert(oncalls, match_args)
|
||||
callbacks[match_args] = function(...)
|
||||
return func(...)
|
||||
end
|
||||
return s
|
||||
end,
|
||||
}
|
||||
end
|
||||
|
||||
return s
|
||||
end
|
||||
|
||||
local function set_stub(state, arguments)
|
||||
state.payload = arguments[1]
|
||||
state.failure_message = arguments[2]
|
||||
end
|
||||
|
||||
assert:register("modifier", "stub", set_stub)
|
||||
|
||||
return setmetatable(stub, {
|
||||
__call = function(self, ...)
|
||||
-- stub originally was a function only. Now that it is a module table
|
||||
-- the __call method is required for backward compatibility
|
||||
return stub.new(...)
|
||||
end,
|
||||
})
|
||||
@ -0,0 +1,386 @@
|
||||
local util = {}
|
||||
local arglist_mt = {}
|
||||
|
||||
-- have pack/unpack both respect the 'n' field
|
||||
local _unpack = table.unpack or unpack
|
||||
local unpack = function(t, i, j)
|
||||
return _unpack(t, i or 1, j or t.n or #t)
|
||||
end
|
||||
local pack = function(...)
|
||||
return {
|
||||
n = select("#", ...),
|
||||
...,
|
||||
}
|
||||
end
|
||||
util.pack = pack
|
||||
util.unpack = unpack
|
||||
|
||||
function util.deepcompare(t1, t2, ignore_mt, cycles, thresh1, thresh2)
|
||||
local ty1 = type(t1)
|
||||
local ty2 = type(t2)
|
||||
-- non-table types can be directly compared
|
||||
if ty1 ~= 'table' or ty2 ~= 'table' then
|
||||
return t1 == t2
|
||||
end
|
||||
local mt1 = debug.getmetatable(t1)
|
||||
local mt2 = debug.getmetatable(t2)
|
||||
-- would equality be determined by metatable __eq?
|
||||
if mt1 and mt1 == mt2 and mt1.__eq then
|
||||
-- then use that unless asked not to
|
||||
if not ignore_mt then
|
||||
return t1 == t2
|
||||
end
|
||||
else -- we can skip the deep comparison below if t1 and t2 share identity
|
||||
if rawequal(t1, t2) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
-- handle recursive tables
|
||||
cycles = cycles or {{}, {}}
|
||||
thresh1, thresh2 = (thresh1 or 1), (thresh2 or 1)
|
||||
cycles[1][t1] = (cycles[1][t1] or 0)
|
||||
cycles[2][t2] = (cycles[2][t2] or 0)
|
||||
if cycles[1][t1] == 1 or cycles[2][t2] == 1 then
|
||||
thresh1 = cycles[1][t1] + 1
|
||||
thresh2 = cycles[2][t2] + 1
|
||||
end
|
||||
if cycles[1][t1] > thresh1 and cycles[2][t2] > thresh2 then
|
||||
return true
|
||||
end
|
||||
|
||||
cycles[1][t1] = cycles[1][t1] + 1
|
||||
cycles[2][t2] = cycles[2][t2] + 1
|
||||
|
||||
for k1, v1 in next, t1 do
|
||||
local v2 = t2[k1]
|
||||
if v2 == nil then
|
||||
return false, {k1}
|
||||
end
|
||||
|
||||
local same, crumbs = util.deepcompare(v1, v2, nil, cycles, thresh1, thresh2)
|
||||
if not same then
|
||||
crumbs = crumbs or {}
|
||||
table.insert(crumbs, k1)
|
||||
return false, crumbs
|
||||
end
|
||||
end
|
||||
for k2, _ in next, t2 do
|
||||
-- only check whether each element has a t1 counterpart, actual comparison
|
||||
-- has been done in first loop above
|
||||
if t1[k2] == nil then
|
||||
return false, {k2}
|
||||
end
|
||||
end
|
||||
|
||||
cycles[1][t1] = cycles[1][t1] - 1
|
||||
cycles[2][t2] = cycles[2][t2] - 1
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function util.shallowcopy(t)
|
||||
if type(t) ~= "table" then
|
||||
return t
|
||||
end
|
||||
local copy = {}
|
||||
setmetatable(copy, getmetatable(t))
|
||||
for k, v in next, t do
|
||||
copy[k] = v
|
||||
end
|
||||
return copy
|
||||
end
|
||||
|
||||
function util.deepcopy(t, deepmt, cache)
|
||||
local spy = require 'luassert.spy'
|
||||
if type(t) ~= "table" then
|
||||
return t
|
||||
end
|
||||
local copy = {}
|
||||
|
||||
-- handle recursive tables
|
||||
local cache = cache or {}
|
||||
if cache[t] then
|
||||
return cache[t]
|
||||
end
|
||||
cache[t] = copy
|
||||
|
||||
for k, v in next, t do
|
||||
copy[k] = (spy.is_spy(v) and v or util.deepcopy(v, deepmt, cache))
|
||||
end
|
||||
if deepmt then
|
||||
debug.setmetatable(copy, util.deepcopy(debug.getmetatable(t, nil, cache)))
|
||||
else
|
||||
debug.setmetatable(copy, debug.getmetatable(t))
|
||||
end
|
||||
return copy
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Copies arguments as a list of arguments
|
||||
-- @param args the arguments of which to copy
|
||||
-- @return the copy of the arguments
|
||||
function util.copyargs(args)
|
||||
local copy = {}
|
||||
setmetatable(copy, getmetatable(args))
|
||||
local match = require 'luassert.match'
|
||||
local spy = require 'luassert.spy'
|
||||
for k, v in pairs(args) do
|
||||
copy[k] = ((match.is_matcher(v) or spy.is_spy(v)) and v or util.deepcopy(v))
|
||||
end
|
||||
return {
|
||||
vals = copy,
|
||||
refs = util.shallowcopy(args),
|
||||
}
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Clear an arguments or return values list from a table
|
||||
-- @param arglist the table to clear of arguments or return values and their count
|
||||
-- @return No return values
|
||||
function util.cleararglist(arglist)
|
||||
for idx = arglist.n, 1, -1 do
|
||||
util.tremove(arglist, idx)
|
||||
end
|
||||
arglist.n = nil
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Test specs against an arglist in deepcopy and refs flavours.
|
||||
-- @param args deepcopy arglist
|
||||
-- @param argsrefs refs arglist
|
||||
-- @param specs arguments/return values to match against args/argsrefs
|
||||
-- @return true if specs match args/argsrefs, false otherwise
|
||||
local function matcharg(args, argrefs, specs)
|
||||
local match = require 'luassert.match'
|
||||
for idx, argval in pairs(args) do
|
||||
local spec = specs[idx]
|
||||
if match.is_matcher(spec) then
|
||||
if match.is_ref_matcher(spec) then
|
||||
argval = argrefs[idx]
|
||||
end
|
||||
if not spec(argval) then
|
||||
return false
|
||||
end
|
||||
elseif (spec == nil or not util.deepcompare(argval, spec)) then
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
for idx, spec in pairs(specs) do
|
||||
-- only check whether each element has an args counterpart,
|
||||
-- actual comparison has been done in first loop above
|
||||
local argval = args[idx]
|
||||
if argval == nil then
|
||||
-- no args counterpart, so try to compare using matcher
|
||||
if match.is_matcher(spec) then
|
||||
if not spec(argval) then
|
||||
return false
|
||||
end
|
||||
else
|
||||
return false
|
||||
end
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Find matching arguments/return values in a saved list of
|
||||
-- arguments/returned values.
|
||||
-- @param invocations_list list of arguments/returned values to search (list of lists)
|
||||
-- @param specs arguments/return values to match against argslist
|
||||
-- @return the last matching arguments/returned values if a match is found, otherwise nil
|
||||
function util.matchargs(invocations_list, specs)
|
||||
-- Search the arguments/returned values last to first to give the
|
||||
-- most helpful answer possible. In the cases where you can place
|
||||
-- your assertions between calls to check this gives you the best
|
||||
-- information if no calls match. In the cases where you can't do
|
||||
-- that there is no good way to predict what would work best.
|
||||
assert(not util.is_arglist(invocations_list), "expected a list of arglist-object, got an arglist")
|
||||
for ii = #invocations_list, 1, -1 do
|
||||
local val = invocations_list[ii]
|
||||
if matcharg(val.vals, val.refs, specs) then
|
||||
return val
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Find matching oncall for an actual call.
|
||||
-- @param oncalls list of oncalls to search
|
||||
-- @param args actual call argslist to match against
|
||||
-- @return the first matching oncall if a match is found, otherwise nil
|
||||
function util.matchoncalls(oncalls, args)
|
||||
for _, callspecs in ipairs(oncalls) do
|
||||
-- This lookup is done immediately on *args* passing into the stub
|
||||
-- so pass *args* as both *args* and *argsref* without copying
|
||||
-- either.
|
||||
if matcharg(args, args, callspecs.vals) then
|
||||
return callspecs
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- table.insert() replacement that respects nil values.
|
||||
-- The function will use table field 'n' as indicator of the
|
||||
-- table length, if not set, it will be added.
|
||||
-- @param t table into which to insert
|
||||
-- @param pos (optional) position in table where to insert. NOTE: not optional if you want to insert a nil-value!
|
||||
-- @param val value to insert
|
||||
-- @return No return values
|
||||
function util.tinsert(...)
|
||||
-- check optional POS value
|
||||
local args = {...}
|
||||
local c = select('#', ...)
|
||||
local t = args[1]
|
||||
local pos = args[2]
|
||||
local val = args[3]
|
||||
if c < 3 then
|
||||
val = pos
|
||||
pos = nil
|
||||
end
|
||||
-- set length indicator n if not present (+1)
|
||||
t.n = (t.n or #t) + 1
|
||||
if not pos then
|
||||
pos = t.n
|
||||
elseif pos > t.n then
|
||||
-- out of our range
|
||||
t[pos] = val
|
||||
t.n = pos
|
||||
end
|
||||
-- shift everything up 1 pos
|
||||
for i = t.n, pos + 1, -1 do
|
||||
t[i] = t[i - 1]
|
||||
end
|
||||
-- add element to be inserted
|
||||
t[pos] = val
|
||||
end
|
||||
-----------------------------------------------
|
||||
-- table.remove() replacement that respects nil values.
|
||||
-- The function will use table field 'n' as indicator of the
|
||||
-- table length, if not set, it will be added.
|
||||
-- @param t table from which to remove
|
||||
-- @param pos (optional) position in table to remove
|
||||
-- @return No return values
|
||||
function util.tremove(t, pos)
|
||||
-- set length indicator n if not present (+1)
|
||||
t.n = t.n or #t
|
||||
if not pos then
|
||||
pos = t.n
|
||||
elseif pos > t.n then
|
||||
local removed = t[pos]
|
||||
-- out of our range
|
||||
t[pos] = nil
|
||||
return removed
|
||||
end
|
||||
local removed = t[pos]
|
||||
-- shift everything up 1 pos
|
||||
for i = pos, t.n do
|
||||
t[i] = t[i + 1]
|
||||
end
|
||||
-- set size, clean last
|
||||
t[t.n] = nil
|
||||
t.n = t.n - 1
|
||||
return removed
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Checks an element to be callable.
|
||||
-- The type must either be a function or have a metatable
|
||||
-- containing an '__call' function.
|
||||
-- @param object element to inspect on being callable or not
|
||||
-- @return boolean, true if the object is callable
|
||||
function util.callable(object)
|
||||
return type(object) == "function" or type((debug.getmetatable(object) or {}).__call) == "function"
|
||||
end
|
||||
-----------------------------------------------
|
||||
-- Checks an element has tostring.
|
||||
-- The type must either be a string or have a metatable
|
||||
-- containing an '__tostring' function.
|
||||
-- @param object element to inspect on having tostring or not
|
||||
-- @return boolean, true if the object has tostring
|
||||
function util.hastostring(object)
|
||||
return type(object) == "string" or type((debug.getmetatable(object) or {}).__tostring) == "function"
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Find the first level, not defined in the same file as the caller's
|
||||
-- code file to properly report an error.
|
||||
-- @param level the level to use as the caller's source file
|
||||
-- @return number, the level of which to report an error
|
||||
function util.errorlevel(level)
|
||||
local level = (level or 1) + 1 -- add one to get level of the caller
|
||||
local info = debug.getinfo(level)
|
||||
local source = (info or {}).source
|
||||
local file = source
|
||||
while file and (file == source or source == "=(tail call)") do
|
||||
level = level + 1
|
||||
info = debug.getinfo(level)
|
||||
source = (info or {}).source
|
||||
end
|
||||
if level > 1 then
|
||||
level = level - 1
|
||||
end -- deduct call to errorlevel() itself
|
||||
return level
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- Extract modifier and namespace keys from list of tokens.
|
||||
-- @param nspace the namespace from which to match tokens
|
||||
-- @param tokens list of tokens to search for keys
|
||||
-- @return table, list of keys that were extracted
|
||||
function util.extract_keys(nspace, tokens)
|
||||
local namespace = require 'luassert.namespaces'
|
||||
|
||||
-- find valid keys by coalescing tokens as needed, starting from the end
|
||||
local keys = {}
|
||||
local key = nil
|
||||
local i = #tokens
|
||||
while i > 0 do
|
||||
local token = tokens[i]
|
||||
key = key and (token .. '_' .. key) or token
|
||||
|
||||
-- find longest matching key in the given namespace
|
||||
local longkey = i > 1 and (tokens[i - 1] .. '_' .. key) or nil
|
||||
while i > 1 and longkey and namespace[nspace][longkey] do
|
||||
key = longkey
|
||||
i = i - 1
|
||||
token = tokens[i]
|
||||
longkey = (token .. '_' .. key)
|
||||
end
|
||||
|
||||
if namespace.modifier[key] or namespace[nspace][key] then
|
||||
table.insert(keys, 1, key)
|
||||
key = nil
|
||||
end
|
||||
i = i - 1
|
||||
end
|
||||
|
||||
-- if there's anything left we didn't recognize it
|
||||
if key then
|
||||
error("luassert: unknown modifier/" .. nspace .. ": '" .. key .. "'", util.errorlevel(2))
|
||||
end
|
||||
|
||||
return keys
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- store argument list for return values of a function in a table.
|
||||
-- The table will get a metatable to identify it as an arglist
|
||||
function util.make_arglist(...)
|
||||
local arglist = {...}
|
||||
arglist.n = select('#', ...) -- add values count for trailing nils
|
||||
return setmetatable(arglist, arglist_mt)
|
||||
end
|
||||
|
||||
-----------------------------------------------
|
||||
-- check a table to be an arglist type.
|
||||
function util.is_arglist(object)
|
||||
return getmetatable(object) == arglist_mt
|
||||
end
|
||||
|
||||
return util
|
||||
Loading…
Reference in New Issue