diff --git a/framework/3rd/greatest/greatest.h b/framework/3rd/greatest/greatest.h new file mode 100644 index 0000000..9022c95 --- /dev/null +++ b/framework/3rd/greatest/greatest.h @@ -0,0 +1,1266 @@ +/* + * Copyright (c) 2011-2021 Scott Vokes + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef GREATEST_H +#define GREATEST_H + +#if defined(__cplusplus) && !defined(GREATEST_NO_EXTERN_CPLUSPLUS) +extern "C" { +#endif + +/* 1.5.0 */ +#define GREATEST_VERSION_MAJOR 1 +#define GREATEST_VERSION_MINOR 5 +#define GREATEST_VERSION_PATCH 0 + +/* A unit testing system for C, contained in 1 file. + * It doesn't use dynamic allocation or depend on anything + * beyond ANSI C89. + * + * An up-to-date version can be found at: + * https://github.com/silentbicycle/greatest/ + */ + + +/********************************************************************* + * Minimal test runner template + *********************************************************************/ +#if 0 + +#include "greatest.h" + +TEST foo_should_foo(void) { + PASS(); +} + +static void setup_cb(void *data) { + printf("setup callback for each test case\n"); +} + +static void teardown_cb(void *data) { + printf("teardown callback for each test case\n"); +} + +SUITE(suite) { + /* Optional setup/teardown callbacks which will be run before/after + * every test case. If using a test suite, they will be cleared when + * the suite finishes. */ + SET_SETUP(setup_cb, voidp_to_callback_data); + SET_TEARDOWN(teardown_cb, voidp_to_callback_data); + + RUN_TEST(foo_should_foo); +} + +/* Add definitions that need to be in the test runner's main file. */ +GREATEST_MAIN_DEFS(); + +/* Set up, run suite(s) of tests, report pass/fail/skip stats. */ +int run_tests(void) { + GREATEST_INIT(); /* init. greatest internals */ + /* List of suites to run (if any). */ + RUN_SUITE(suite); + + /* Tests can also be run directly, without using test suites. */ + RUN_TEST(foo_should_foo); + + GREATEST_PRINT_REPORT(); /* display results */ + return greatest_all_passed(); +} + +/* main(), for a standalone command-line test runner. + * This replaces run_tests above, and adds command line option + * handling and exiting with a pass/fail status. */ +int main(int argc, char **argv) { + GREATEST_MAIN_BEGIN(); /* init & parse command-line args */ + RUN_SUITE(suite); + GREATEST_MAIN_END(); /* display results */ +} + +#endif +/*********************************************************************/ + + +#include +#include +#include +#include + +/*********** + * Options * + ***********/ + +/* Default column width for non-verbose output. */ +#ifndef GREATEST_DEFAULT_WIDTH +#define GREATEST_DEFAULT_WIDTH 72 +#endif + +/* FILE *, for test logging. */ +#ifndef GREATEST_STDOUT +#define GREATEST_STDOUT stdout +#endif + +/* Remove GREATEST_ prefix from most commonly used symbols? */ +#ifndef GREATEST_USE_ABBREVS +#define GREATEST_USE_ABBREVS 1 +#endif + +/* Set to 0 to disable all use of setjmp/longjmp. */ +#ifndef GREATEST_USE_LONGJMP +#define GREATEST_USE_LONGJMP 0 +#endif + +/* Make it possible to replace fprintf with another + * function with the same interface. */ +#ifndef GREATEST_FPRINTF +#define GREATEST_FPRINTF fprintf +#endif + +#if GREATEST_USE_LONGJMP +#include +#endif + +/* Set to 0 to disable all use of time.h / clock(). */ +#ifndef GREATEST_USE_TIME +#define GREATEST_USE_TIME 1 +#endif + +#if GREATEST_USE_TIME +#include +#endif + +/* Floating point type, for ASSERT_IN_RANGE. */ +#ifndef GREATEST_FLOAT +#define GREATEST_FLOAT double +#define GREATEST_FLOAT_FMT "%g" +#endif + +/* Size of buffer for test name + optional '_' separator and suffix */ +#ifndef GREATEST_TESTNAME_BUF_SIZE +#define GREATEST_TESTNAME_BUF_SIZE 128 +#endif + + +/********* + * Types * + *********/ + +/* Info for the current running suite. */ +typedef struct greatest_suite_info { + unsigned int tests_run; + unsigned int passed; + unsigned int failed; + unsigned int skipped; + +#if GREATEST_USE_TIME + /* timers, pre/post running suite and individual tests */ + clock_t pre_suite; + clock_t post_suite; + clock_t pre_test; + clock_t post_test; +#endif +} greatest_suite_info; + +/* Type for a suite function. */ +typedef void greatest_suite_cb(void); + +/* Types for setup/teardown callbacks. If non-NULL, these will be run + * and passed the pointer to their additional data. */ +typedef void greatest_setup_cb(void *udata); +typedef void greatest_teardown_cb(void *udata); + +/* Type for an equality comparison between two pointers of the same type. + * Should return non-0 if equal, otherwise 0. + * UDATA is a closure value, passed through from ASSERT_EQUAL_T[m]. */ +typedef int greatest_equal_cb(const void *expd, const void *got, void *udata); + +/* Type for a callback that prints a value pointed to by T. + * Return value has the same meaning as printf's. + * UDATA is a closure value, passed through from ASSERT_EQUAL_T[m]. */ +typedef int greatest_printf_cb(const void *t, void *udata); + +/* Callbacks for an arbitrary type; needed for type-specific + * comparisons via GREATEST_ASSERT_EQUAL_T[m].*/ +typedef struct greatest_type_info { + greatest_equal_cb *equal; + greatest_printf_cb *print; +} greatest_type_info; + +typedef struct greatest_memory_cmp_env { + const unsigned char *exp; + const unsigned char *got; + size_t size; +} greatest_memory_cmp_env; + +/* Callbacks for string and raw memory types. */ +extern greatest_type_info greatest_type_info_string; +extern greatest_type_info greatest_type_info_memory; + +typedef enum { + GREATEST_FLAG_FIRST_FAIL = 0x01, + GREATEST_FLAG_LIST_ONLY = 0x02, + GREATEST_FLAG_ABORT_ON_FAIL = 0x04 +} greatest_flag_t; + +/* Internal state for a PRNG, used to shuffle test order. */ +struct greatest_prng { + unsigned char random_order; /* use random ordering? */ + unsigned char initialized; /* is random ordering initialized? */ + unsigned char pad_0[6]; + unsigned long state; /* PRNG state */ + unsigned long count; /* how many tests, this pass */ + unsigned long count_ceil; /* total number of tests */ + unsigned long count_run; /* total tests run */ + unsigned long a; /* LCG multiplier */ + unsigned long c; /* LCG increment */ + unsigned long m; /* LCG modulus, based on count_ceil */ +}; + +/* Struct containing all test runner state. */ +typedef struct greatest_run_info { + unsigned char flags; + unsigned char verbosity; + unsigned char running_test; /* guard for nested RUN_TEST calls */ + unsigned char exact_name_match; + + unsigned int tests_run; /* total test count */ + + /* currently running test suite */ + greatest_suite_info suite; + + /* overall pass/fail/skip counts */ + unsigned int passed; + unsigned int failed; + unsigned int skipped; + unsigned int assertions; + + /* info to print about the most recent failure */ + unsigned int fail_line; + unsigned int pad_1; + const char *fail_file; + const char *msg; + + /* current setup/teardown hooks and userdata */ + greatest_setup_cb *setup; + void *setup_udata; + greatest_teardown_cb *teardown; + void *teardown_udata; + + /* formatting info for ".....s...F"-style output */ + unsigned int col; + unsigned int width; + + /* only run a specific suite or test */ + const char *suite_filter; + const char *test_filter; + const char *test_exclude; + const char *name_suffix; /* print suffix with test name */ + char name_buf[GREATEST_TESTNAME_BUF_SIZE]; + + struct greatest_prng prng[2]; /* 0: suites, 1: tests */ + +#if GREATEST_USE_TIME + /* overall timers */ + clock_t begin; + clock_t end; +#endif + +#if GREATEST_USE_LONGJMP + int pad_jmp_buf; + unsigned char pad_2[4]; + jmp_buf jump_dest; +#endif +} greatest_run_info; + +struct greatest_report_t { + /* overall pass/fail/skip counts */ + unsigned int passed; + unsigned int failed; + unsigned int skipped; + unsigned int assertions; +}; + +/* Global var for the current testing context. + * Initialized by GREATEST_MAIN_DEFS(). */ +extern greatest_run_info greatest_info; + +/* Type for ASSERT_ENUM_EQ's ENUM_STR argument. */ +typedef const char *greatest_enum_str_fun(int value); + + +/********************** + * Exported functions * + **********************/ + +/* These are used internally by greatest macros. */ +int greatest_test_pre(const char *name); +void greatest_test_post(int res); +int greatest_do_assert_equal_t(const void *expd, const void *got, + greatest_type_info *type_info, void *udata); +void greatest_prng_init_first_pass(int id); +int greatest_prng_init_second_pass(int id, unsigned long seed); +void greatest_prng_step(int id); + +/* These are part of the public greatest API. */ +void GREATEST_SET_SETUP_CB(greatest_setup_cb *cb, void *udata); +void GREATEST_SET_TEARDOWN_CB(greatest_teardown_cb *cb, void *udata); +void GREATEST_INIT(void); +void GREATEST_PRINT_REPORT(void); +int greatest_all_passed(void); +void greatest_set_suite_filter(const char *filter); +void greatest_set_test_filter(const char *filter); +void greatest_set_test_exclude(const char *filter); +void greatest_set_exact_name_match(void); +void greatest_stop_at_first_fail(void); +void greatest_abort_on_fail(void); +void greatest_list_only(void); +void greatest_get_report(struct greatest_report_t *report); +unsigned int greatest_get_verbosity(void); +void greatest_set_verbosity(unsigned int verbosity); +void greatest_set_flag(greatest_flag_t flag); +void greatest_set_test_suffix(const char *suffix); + + +/******************** +* Language Support * +********************/ + +/* If __VA_ARGS__ (C99) is supported, allow parametric testing +* without needing to manually manage the argument struct. */ +#if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 19901L) || \ + (defined(_MSC_VER) && _MSC_VER >= 1800) +#define GREATEST_VA_ARGS +#endif + + +/********** + * Macros * + **********/ + +/* Define a suite. (The duplication is intentional -- it eliminates + * a warning from -Wmissing-declarations.) */ +#define GREATEST_SUITE(NAME) void NAME(void); void NAME(void) + +/* Declare a suite, provided by another compilation unit. */ +#define GREATEST_SUITE_EXTERN(NAME) void NAME(void) + +/* Start defining a test function. + * The arguments are not included, to allow parametric testing. */ +#define GREATEST_TEST static enum greatest_test_res + +/* PASS/FAIL/SKIP result from a test. Used internally. */ +typedef enum greatest_test_res { + GREATEST_TEST_RES_PASS = 0, + GREATEST_TEST_RES_FAIL = -1, + GREATEST_TEST_RES_SKIP = 1 +} greatest_test_res; + +/* Run a suite. */ +#define GREATEST_RUN_SUITE(S_NAME) greatest_run_suite(S_NAME, #S_NAME) + +/* Run a test in the current suite. */ +#define GREATEST_RUN_TEST(TEST) \ + do { \ + if (greatest_test_pre(#TEST) == 1) { \ + enum greatest_test_res res = GREATEST_SAVE_CONTEXT(); \ + if (res == GREATEST_TEST_RES_PASS) { \ + res = TEST(); \ + } \ + greatest_test_post(res); \ + } \ + } while (0) + +/* Ignore a test, don't warn about it being unused. */ +#define GREATEST_IGNORE_TEST(TEST) (void)TEST + +/* Run a test in the current suite with one void * argument, + * which can be a pointer to a struct with multiple arguments. */ +#define GREATEST_RUN_TEST1(TEST, ENV) \ + do { \ + if (greatest_test_pre(#TEST) == 1) { \ + enum greatest_test_res res = GREATEST_SAVE_CONTEXT(); \ + if (res == GREATEST_TEST_RES_PASS) { \ + res = TEST(ENV); \ + } \ + greatest_test_post(res); \ + } \ + } while (0) + +#ifdef GREATEST_VA_ARGS +#define GREATEST_RUN_TESTp(TEST, ...) \ + do { \ + if (greatest_test_pre(#TEST) == 1) { \ + enum greatest_test_res res = GREATEST_SAVE_CONTEXT(); \ + if (res == GREATEST_TEST_RES_PASS) { \ + res = TEST(__VA_ARGS__); \ + } \ + greatest_test_post(res); \ + } \ + } while (0) +#endif + + +/* Check if the test runner is in verbose mode. */ +#define GREATEST_IS_VERBOSE() ((greatest_info.verbosity) > 0) +#define GREATEST_LIST_ONLY() \ + (greatest_info.flags & GREATEST_FLAG_LIST_ONLY) +#define GREATEST_FIRST_FAIL() \ + (greatest_info.flags & GREATEST_FLAG_FIRST_FAIL) +#define GREATEST_ABORT_ON_FAIL() \ + (greatest_info.flags & GREATEST_FLAG_ABORT_ON_FAIL) +#define GREATEST_FAILURE_ABORT() \ + (GREATEST_FIRST_FAIL() && \ + (greatest_info.suite.failed > 0 || greatest_info.failed > 0)) + +/* Message-less forms of tests defined below. */ +#define GREATEST_PASS() GREATEST_PASSm(NULL) +#define GREATEST_FAIL() GREATEST_FAILm(NULL) +#define GREATEST_SKIP() GREATEST_SKIPm(NULL) +#define GREATEST_ASSERT(COND) \ + GREATEST_ASSERTm(#COND, COND) +#define GREATEST_ASSERT_OR_LONGJMP(COND) \ + GREATEST_ASSERT_OR_LONGJMPm(#COND, COND) +#define GREATEST_ASSERT_FALSE(COND) \ + GREATEST_ASSERT_FALSEm(#COND, COND) +#define GREATEST_ASSERT_EQ(EXP, GOT) \ + GREATEST_ASSERT_EQm(#EXP " != " #GOT, EXP, GOT) +#define GREATEST_ASSERT_NEQ(EXP, GOT) \ + GREATEST_ASSERT_NEQm(#EXP " == " #GOT, EXP, GOT) +#define GREATEST_ASSERT_GT(EXP, GOT) \ + GREATEST_ASSERT_GTm(#EXP " <= " #GOT, EXP, GOT) +#define GREATEST_ASSERT_GTE(EXP, GOT) \ + GREATEST_ASSERT_GTEm(#EXP " < " #GOT, EXP, GOT) +#define GREATEST_ASSERT_LT(EXP, GOT) \ + GREATEST_ASSERT_LTm(#EXP " >= " #GOT, EXP, GOT) +#define GREATEST_ASSERT_LTE(EXP, GOT) \ + GREATEST_ASSERT_LTEm(#EXP " > " #GOT, EXP, GOT) +#define GREATEST_ASSERT_EQ_FMT(EXP, GOT, FMT) \ + GREATEST_ASSERT_EQ_FMTm(#EXP " != " #GOT, EXP, GOT, FMT) +#define GREATEST_ASSERT_IN_RANGE(EXP, GOT, TOL) \ + GREATEST_ASSERT_IN_RANGEm(#EXP " != " #GOT " +/- " #TOL, EXP, GOT, TOL) +#define GREATEST_ASSERT_EQUAL_T(EXP, GOT, TYPE_INFO, UDATA) \ + GREATEST_ASSERT_EQUAL_Tm(#EXP " != " #GOT, EXP, GOT, TYPE_INFO, UDATA) +#define GREATEST_ASSERT_STR_EQ(EXP, GOT) \ + GREATEST_ASSERT_STR_EQm(#EXP " != " #GOT, EXP, GOT) +#define GREATEST_ASSERT_STRN_EQ(EXP, GOT, SIZE) \ + GREATEST_ASSERT_STRN_EQm(#EXP " != " #GOT, EXP, GOT, SIZE) +#define GREATEST_ASSERT_MEM_EQ(EXP, GOT, SIZE) \ + GREATEST_ASSERT_MEM_EQm(#EXP " != " #GOT, EXP, GOT, SIZE) +#define GREATEST_ASSERT_ENUM_EQ(EXP, GOT, ENUM_STR) \ + GREATEST_ASSERT_ENUM_EQm(#EXP " != " #GOT, EXP, GOT, ENUM_STR) + +/* The following forms take an additional message argument first, + * to be displayed by the test runner. */ + +/* Fail if a condition is not true, with message. */ +#define GREATEST_ASSERTm(MSG, COND) \ + do { \ + greatest_info.assertions++; \ + if (!(COND)) { GREATEST_FAILm(MSG); } \ + } while (0) + +/* Fail if a condition is not true, longjmping out of test. */ +#define GREATEST_ASSERT_OR_LONGJMPm(MSG, COND) \ + do { \ + greatest_info.assertions++; \ + if (!(COND)) { GREATEST_FAIL_WITH_LONGJMPm(MSG); } \ + } while (0) + +/* Fail if a condition is not false, with message. */ +#define GREATEST_ASSERT_FALSEm(MSG, COND) \ + do { \ + greatest_info.assertions++; \ + if ((COND)) { GREATEST_FAILm(MSG); } \ + } while (0) + +/* Internal macro for relational assertions */ +#define GREATEST__REL(REL, MSG, EXP, GOT) \ + do { \ + greatest_info.assertions++; \ + if (!((EXP) REL (GOT))) { GREATEST_FAILm(MSG); } \ + } while (0) + +/* Fail if EXP is not ==, !=, >, <, >=, or <= to GOT. */ +#define GREATEST_ASSERT_EQm(MSG,E,G) GREATEST__REL(==, MSG,E,G) +#define GREATEST_ASSERT_NEQm(MSG,E,G) GREATEST__REL(!=, MSG,E,G) +#define GREATEST_ASSERT_GTm(MSG,E,G) GREATEST__REL(>, MSG,E,G) +#define GREATEST_ASSERT_GTEm(MSG,E,G) GREATEST__REL(>=, MSG,E,G) +#define GREATEST_ASSERT_LTm(MSG,E,G) GREATEST__REL(<, MSG,E,G) +#define GREATEST_ASSERT_LTEm(MSG,E,G) GREATEST__REL(<=, MSG,E,G) + +/* Fail if EXP != GOT (equality comparison by ==). + * Warning: FMT, EXP, and GOT will be evaluated more + * than once on failure. */ +#define GREATEST_ASSERT_EQ_FMTm(MSG, EXP, GOT, FMT) \ + do { \ + greatest_info.assertions++; \ + if ((EXP) != (GOT)) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\nExpected: "); \ + GREATEST_FPRINTF(GREATEST_STDOUT, FMT, EXP); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n Got: "); \ + GREATEST_FPRINTF(GREATEST_STDOUT, FMT, GOT); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n"); \ + GREATEST_FAILm(MSG); \ + } \ + } while (0) + +/* Fail if EXP is not equal to GOT, printing enum IDs. */ +#define GREATEST_ASSERT_ENUM_EQm(MSG, EXP, GOT, ENUM_STR) \ + do { \ + int greatest_EXP = (int)(EXP); \ + int greatest_GOT = (int)(GOT); \ + greatest_enum_str_fun *greatest_ENUM_STR = ENUM_STR; \ + if (greatest_EXP != greatest_GOT) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\nExpected: %s", \ + greatest_ENUM_STR(greatest_EXP)); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n Got: %s\n", \ + greatest_ENUM_STR(greatest_GOT)); \ + GREATEST_FAILm(MSG); \ + } \ + } while (0) \ + +/* Fail if GOT not in range of EXP +|- TOL. */ +#define GREATEST_ASSERT_IN_RANGEm(MSG, EXP, GOT, TOL) \ + do { \ + GREATEST_FLOAT greatest_EXP = (EXP); \ + GREATEST_FLOAT greatest_GOT = (GOT); \ + GREATEST_FLOAT greatest_TOL = (TOL); \ + greatest_info.assertions++; \ + if ((greatest_EXP > greatest_GOT && \ + greatest_EXP - greatest_GOT > greatest_TOL) || \ + (greatest_EXP < greatest_GOT && \ + greatest_GOT - greatest_EXP > greatest_TOL)) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "\nExpected: " GREATEST_FLOAT_FMT \ + " +/- " GREATEST_FLOAT_FMT \ + "\n Got: " GREATEST_FLOAT_FMT \ + "\n", \ + greatest_EXP, greatest_TOL, greatest_GOT); \ + GREATEST_FAILm(MSG); \ + } \ + } while (0) + +/* Fail if EXP is not equal to GOT, according to strcmp. */ +#define GREATEST_ASSERT_STR_EQm(MSG, EXP, GOT) \ + do { \ + GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, \ + &greatest_type_info_string, NULL); \ + } while (0) \ + +/* Fail if EXP is not equal to GOT, according to strncmp. */ +#define GREATEST_ASSERT_STRN_EQm(MSG, EXP, GOT, SIZE) \ + do { \ + size_t size = SIZE; \ + GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, \ + &greatest_type_info_string, &size); \ + } while (0) \ + +/* Fail if EXP is not equal to GOT, according to memcmp. */ +#define GREATEST_ASSERT_MEM_EQm(MSG, EXP, GOT, SIZE) \ + do { \ + greatest_memory_cmp_env env; \ + env.exp = (const unsigned char *)EXP; \ + env.got = (const unsigned char *)GOT; \ + env.size = SIZE; \ + GREATEST_ASSERT_EQUAL_Tm(MSG, env.exp, env.got, \ + &greatest_type_info_memory, &env); \ + } while (0) \ + +/* Fail if EXP is not equal to GOT, according to a comparison + * callback in TYPE_INFO. If they are not equal, optionally use a + * print callback in TYPE_INFO to print them. */ +#define GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, TYPE_INFO, UDATA) \ + do { \ + greatest_type_info *type_info = (TYPE_INFO); \ + greatest_info.assertions++; \ + if (!greatest_do_assert_equal_t(EXP, GOT, \ + type_info, UDATA)) { \ + if (type_info == NULL || type_info->equal == NULL) { \ + GREATEST_FAILm("type_info->equal callback missing!"); \ + } else { \ + GREATEST_FAILm(MSG); \ + } \ + } \ + } while (0) \ + +/* Pass. */ +#define GREATEST_PASSm(MSG) \ + do { \ + greatest_info.msg = MSG; \ + return GREATEST_TEST_RES_PASS; \ + } while (0) + +/* Fail. */ +#define GREATEST_FAILm(MSG) \ + do { \ + greatest_info.fail_file = __FILE__; \ + greatest_info.fail_line = __LINE__; \ + greatest_info.msg = MSG; \ + if (GREATEST_ABORT_ON_FAIL()) { abort(); } \ + return GREATEST_TEST_RES_FAIL; \ + } while (0) + +/* Optional GREATEST_FAILm variant that longjmps. */ +#if GREATEST_USE_LONGJMP +#define GREATEST_FAIL_WITH_LONGJMP() GREATEST_FAIL_WITH_LONGJMPm(NULL) +#define GREATEST_FAIL_WITH_LONGJMPm(MSG) \ + do { \ + greatest_info.fail_file = __FILE__; \ + greatest_info.fail_line = __LINE__; \ + greatest_info.msg = MSG; \ + longjmp(greatest_info.jump_dest, GREATEST_TEST_RES_FAIL); \ + } while (0) +#endif + +/* Skip the current test. */ +#define GREATEST_SKIPm(MSG) \ + do { \ + greatest_info.msg = MSG; \ + return GREATEST_TEST_RES_SKIP; \ + } while (0) + +/* Check the result of a subfunction using ASSERT, etc. */ +#define GREATEST_CHECK_CALL(RES) \ + do { \ + enum greatest_test_res greatest_RES = RES; \ + if (greatest_RES != GREATEST_TEST_RES_PASS) { \ + return greatest_RES; \ + } \ + } while (0) \ + +#if GREATEST_USE_TIME +#define GREATEST_SET_TIME(NAME) \ + NAME = clock(); \ + if (NAME == (clock_t) -1) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "clock error: %s\n", #NAME); \ + exit(EXIT_FAILURE); \ + } + +#define GREATEST_CLOCK_DIFF(C1, C2) \ + GREATEST_FPRINTF(GREATEST_STDOUT, " (%lu ticks, %.3f sec)", \ + (long unsigned int) (C2) - (long unsigned int)(C1), \ + (double)((C2) - (C1)) / (1.0 * (double)CLOCKS_PER_SEC)) +#else +#define GREATEST_SET_TIME(UNUSED) +#define GREATEST_CLOCK_DIFF(UNUSED1, UNUSED2) +#endif + +#if GREATEST_USE_LONGJMP +#define GREATEST_SAVE_CONTEXT() \ + /* setjmp returns 0 (GREATEST_TEST_RES_PASS) on first call * \ + * so the test runs, then RES_FAIL from FAIL_WITH_LONGJMP. */ \ + ((enum greatest_test_res)(setjmp(greatest_info.jump_dest))) +#else +#define GREATEST_SAVE_CONTEXT() \ + /*a no-op, since setjmp/longjmp aren't being used */ \ + GREATEST_TEST_RES_PASS +#endif + +/* Run every suite / test function run within BODY in pseudo-random + * order, seeded by SEED. (The top 3 bits of the seed are ignored.) + * + * This should be called like: + * GREATEST_SHUFFLE_TESTS(seed, { + * GREATEST_RUN_TEST(some_test); + * GREATEST_RUN_TEST(some_other_test); + * GREATEST_RUN_TEST(yet_another_test); + * }); + * + * Note that the body of the second argument will be evaluated + * multiple times. */ +#define GREATEST_SHUFFLE_SUITES(SD, BODY) GREATEST_SHUFFLE(0, SD, BODY) +#define GREATEST_SHUFFLE_TESTS(SD, BODY) GREATEST_SHUFFLE(1, SD, BODY) +#define GREATEST_SHUFFLE(ID, SD, BODY) \ + do { \ + struct greatest_prng *prng = &greatest_info.prng[ID]; \ + greatest_prng_init_first_pass(ID); \ + do { \ + prng->count = 0; \ + if (prng->initialized) { greatest_prng_step(ID); } \ + BODY; \ + if (!prng->initialized) { \ + if (!greatest_prng_init_second_pass(ID, SD)) { break; } \ + } else if (prng->count_run == prng->count_ceil) { \ + break; \ + } \ + } while (!GREATEST_FAILURE_ABORT()); \ + prng->count_run = prng->random_order = prng->initialized = 0; \ + } while(0) + +/* Include several function definitions in the main test file. */ +#define GREATEST_MAIN_DEFS() \ + \ +/* Is FILTER a subset of NAME? */ \ +static int greatest_name_match(const char *name, const char *filter, \ + int res_if_none) { \ + size_t offset = 0; \ + size_t filter_len = filter ? strlen(filter) : 0; \ + if (filter_len == 0) { return res_if_none; } /* no filter */ \ + if (greatest_info.exact_name_match && strlen(name) != filter_len) { \ + return 0; /* ignore substring matches */ \ + } \ + while (name[offset] != '\0') { \ + if (name[offset] == filter[0]) { \ + if (0 == strncmp(&name[offset], filter, filter_len)) { \ + return 1; \ + } \ + } \ + offset++; \ + } \ + \ + return 0; \ +} \ + \ +static void greatest_buffer_test_name(const char *name) { \ + struct greatest_run_info *g = &greatest_info; \ + size_t len = strlen(name), size = sizeof(g->name_buf); \ + memset(g->name_buf, 0x00, size); \ + (void)strncat(g->name_buf, name, size - 1); \ + if (g->name_suffix && (len + 1 < size)) { \ + g->name_buf[len] = '_'; \ + strncat(&g->name_buf[len+1], g->name_suffix, size-(len+2)); \ + } \ +} \ + \ +/* Before running a test, check the name filtering and \ + * test shuffling state, if applicable, and then call setup hooks. */ \ +int greatest_test_pre(const char *name) { \ + struct greatest_run_info *g = &greatest_info; \ + int match; \ + greatest_buffer_test_name(name); \ + match = greatest_name_match(g->name_buf, g->test_filter, 1) && \ + !greatest_name_match(g->name_buf, g->test_exclude, 0); \ + if (GREATEST_LIST_ONLY()) { /* just listing test names */ \ + if (match) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, " %s\n", g->name_buf); \ + } \ + goto clear; \ + } \ + if (match && (!GREATEST_FIRST_FAIL() || g->suite.failed == 0)) { \ + struct greatest_prng *p = &g->prng[1]; \ + if (p->random_order) { \ + p->count++; \ + if (!p->initialized || ((p->count - 1) != p->state)) { \ + goto clear; /* don't run this test yet */ \ + } \ + } \ + if (g->running_test) { \ + fprintf(stderr, "Error: Test run inside another test.\n"); \ + return 0; \ + } \ + GREATEST_SET_TIME(g->suite.pre_test); \ + if (g->setup) { g->setup(g->setup_udata); } \ + p->count_run++; \ + g->running_test = 1; \ + return 1; /* test should be run */ \ + } else { \ + goto clear; /* skipped */ \ + } \ +clear: \ + g->name_suffix = NULL; \ + return 0; \ +} \ + \ +static void greatest_do_pass(void) { \ + struct greatest_run_info *g = &greatest_info; \ + if (GREATEST_IS_VERBOSE()) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "PASS %s: %s", \ + g->name_buf, g->msg ? g->msg : ""); \ + } else { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "."); \ + } \ + g->suite.passed++; \ +} \ + \ +static void greatest_do_fail(void) { \ + struct greatest_run_info *g = &greatest_info; \ + if (GREATEST_IS_VERBOSE()) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "FAIL %s: %s (%s:%u)", g->name_buf, \ + g->msg ? g->msg : "", g->fail_file, g->fail_line); \ + } else { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "F"); \ + g->col++; /* add linebreak if in line of '.'s */ \ + if (g->col != 0) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n"); \ + g->col = 0; \ + } \ + GREATEST_FPRINTF(GREATEST_STDOUT, "FAIL %s: %s (%s:%u)\n", \ + g->name_buf, g->msg ? g->msg : "", \ + g->fail_file, g->fail_line); \ + } \ + g->suite.failed++; \ +} \ + \ +static void greatest_do_skip(void) { \ + struct greatest_run_info *g = &greatest_info; \ + if (GREATEST_IS_VERBOSE()) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "SKIP %s: %s", \ + g->name_buf, g->msg ? g->msg : ""); \ + } else { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "s"); \ + } \ + g->suite.skipped++; \ +} \ + \ +void greatest_test_post(int res) { \ + GREATEST_SET_TIME(greatest_info.suite.post_test); \ + if (greatest_info.teardown) { \ + void *udata = greatest_info.teardown_udata; \ + greatest_info.teardown(udata); \ + } \ + \ + greatest_info.running_test = 0; \ + if (res <= GREATEST_TEST_RES_FAIL) { \ + greatest_do_fail(); \ + } else if (res >= GREATEST_TEST_RES_SKIP) { \ + greatest_do_skip(); \ + } else if (res == GREATEST_TEST_RES_PASS) { \ + greatest_do_pass(); \ + } \ + greatest_info.name_suffix = NULL; \ + greatest_info.suite.tests_run++; \ + greatest_info.col++; \ + if (GREATEST_IS_VERBOSE()) { \ + GREATEST_CLOCK_DIFF(greatest_info.suite.pre_test, \ + greatest_info.suite.post_test); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n"); \ + } else if (greatest_info.col % greatest_info.width == 0) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n"); \ + greatest_info.col = 0; \ + } \ + fflush(GREATEST_STDOUT); \ +} \ + \ +static void report_suite(void) { \ + if (greatest_info.suite.tests_run > 0) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "\n%u test%s - %u passed, %u failed, %u skipped", \ + greatest_info.suite.tests_run, \ + greatest_info.suite.tests_run == 1 ? "" : "s", \ + greatest_info.suite.passed, \ + greatest_info.suite.failed, \ + greatest_info.suite.skipped); \ + GREATEST_CLOCK_DIFF(greatest_info.suite.pre_suite, \ + greatest_info.suite.post_suite); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n"); \ + } \ +} \ + \ +static void update_counts_and_reset_suite(void) { \ + greatest_info.setup = NULL; \ + greatest_info.setup_udata = NULL; \ + greatest_info.teardown = NULL; \ + greatest_info.teardown_udata = NULL; \ + greatest_info.passed += greatest_info.suite.passed; \ + greatest_info.failed += greatest_info.suite.failed; \ + greatest_info.skipped += greatest_info.suite.skipped; \ + greatest_info.tests_run += greatest_info.suite.tests_run; \ + memset(&greatest_info.suite, 0, sizeof(greatest_info.suite)); \ + greatest_info.col = 0; \ +} \ + \ +static int greatest_suite_pre(const char *suite_name) { \ + struct greatest_prng *p = &greatest_info.prng[0]; \ + if (!greatest_name_match(suite_name, greatest_info.suite_filter, 1) \ + || (GREATEST_FAILURE_ABORT())) { return 0; } \ + if (p->random_order) { \ + p->count++; \ + if (!p->initialized || ((p->count - 1) != p->state)) { \ + return 0; /* don't run this suite yet */ \ + } \ + } \ + p->count_run++; \ + update_counts_and_reset_suite(); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n* Suite %s:\n", suite_name); \ + GREATEST_SET_TIME(greatest_info.suite.pre_suite); \ + return 1; \ +} \ + \ +static void greatest_suite_post(void) { \ + GREATEST_SET_TIME(greatest_info.suite.post_suite); \ + report_suite(); \ +} \ + \ +static void greatest_run_suite(greatest_suite_cb *suite_cb, \ + const char *suite_name) { \ + if (greatest_suite_pre(suite_name)) { \ + suite_cb(); \ + greatest_suite_post(); \ + } \ +} \ + \ +int greatest_do_assert_equal_t(const void *expd, const void *got, \ + greatest_type_info *type_info, void *udata) { \ + int eq = 0; \ + if (type_info == NULL || type_info->equal == NULL) { return 0; } \ + eq = type_info->equal(expd, got, udata); \ + if (!eq) { \ + if (type_info->print != NULL) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\nExpected: "); \ + (void)type_info->print(expd, udata); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n Got: "); \ + (void)type_info->print(got, udata); \ + GREATEST_FPRINTF(GREATEST_STDOUT, "\n"); \ + } \ + } \ + return eq; \ +} \ + \ +static void greatest_usage(const char *name) { \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "Usage: %s [-hlfavex] [-s SUITE] [-t TEST] [-x EXCLUDE]\n" \ + " -h, --help print this Help\n" \ + " -l List suites and tests, then exit (dry run)\n" \ + " -f Stop runner after first failure\n" \ + " -a Abort on first failure (implies -f)\n" \ + " -v Verbose output\n" \ + " -s SUITE only run suites containing substring SUITE\n" \ + " -t TEST only run tests containing substring TEST\n" \ + " -e only run exact name match for -s or -t\n" \ + " -x EXCLUDE exclude tests containing substring EXCLUDE\n", \ + name); \ +} \ + \ +static void greatest_parse_options(int argc, char **argv) { \ + int i = 0; \ + for (i = 1; i < argc; i++) { \ + if (argv[i][0] == '-') { \ + char f = argv[i][1]; \ + if ((f == 's' || f == 't' || f == 'x') && argc <= i + 1) { \ + greatest_usage(argv[0]); exit(EXIT_FAILURE); \ + } \ + switch (f) { \ + case 's': /* suite name filter */ \ + greatest_set_suite_filter(argv[i + 1]); i++; break; \ + case 't': /* test name filter */ \ + greatest_set_test_filter(argv[i + 1]); i++; break; \ + case 'x': /* test name exclusion */ \ + greatest_set_test_exclude(argv[i + 1]); i++; break; \ + case 'e': /* exact name match */ \ + greatest_set_exact_name_match(); break; \ + case 'f': /* first fail flag */ \ + greatest_stop_at_first_fail(); break; \ + case 'a': /* abort() on fail flag */ \ + greatest_abort_on_fail(); break; \ + case 'l': /* list only (dry run) */ \ + greatest_list_only(); break; \ + case 'v': /* first fail flag */ \ + greatest_info.verbosity++; break; \ + case 'h': /* help */ \ + greatest_usage(argv[0]); exit(EXIT_SUCCESS); \ + default: \ + case '-': \ + if (0 == strncmp("--help", argv[i], 6)) { \ + greatest_usage(argv[0]); exit(EXIT_SUCCESS); \ + } else if (0 == strcmp("--", argv[i])) { \ + return; /* ignore following arguments */ \ + } \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "Unknown argument '%s'\n", argv[i]); \ + greatest_usage(argv[0]); \ + exit(EXIT_FAILURE); \ + } \ + } \ + } \ +} \ + \ +int greatest_all_passed(void) { return (greatest_info.failed == 0); } \ + \ +void greatest_set_test_filter(const char *filter) { \ + greatest_info.test_filter = filter; \ +} \ + \ +void greatest_set_test_exclude(const char *filter) { \ + greatest_info.test_exclude = filter; \ +} \ + \ +void greatest_set_suite_filter(const char *filter) { \ + greatest_info.suite_filter = filter; \ +} \ + \ +void greatest_set_exact_name_match(void) { \ + greatest_info.exact_name_match = 1; \ +} \ + \ +void greatest_stop_at_first_fail(void) { \ + greatest_set_flag(GREATEST_FLAG_FIRST_FAIL); \ +} \ + \ +void greatest_abort_on_fail(void) { \ + greatest_set_flag(GREATEST_FLAG_ABORT_ON_FAIL); \ +} \ + \ +void greatest_list_only(void) { \ + greatest_set_flag(GREATEST_FLAG_LIST_ONLY); \ +} \ + \ +void greatest_get_report(struct greatest_report_t *report) { \ + if (report) { \ + report->passed = greatest_info.passed; \ + report->failed = greatest_info.failed; \ + report->skipped = greatest_info.skipped; \ + report->assertions = greatest_info.assertions; \ + } \ +} \ + \ +unsigned int greatest_get_verbosity(void) { \ + return greatest_info.verbosity; \ +} \ + \ +void greatest_set_verbosity(unsigned int verbosity) { \ + greatest_info.verbosity = (unsigned char)verbosity; \ +} \ + \ +void greatest_set_flag(greatest_flag_t flag) { \ + greatest_info.flags = (unsigned char)(greatest_info.flags | flag); \ +} \ + \ +void greatest_set_test_suffix(const char *suffix) { \ + greatest_info.name_suffix = suffix; \ +} \ + \ +void GREATEST_SET_SETUP_CB(greatest_setup_cb *cb, void *udata) { \ + greatest_info.setup = cb; \ + greatest_info.setup_udata = udata; \ +} \ + \ +void GREATEST_SET_TEARDOWN_CB(greatest_teardown_cb *cb, void *udata) { \ + greatest_info.teardown = cb; \ + greatest_info.teardown_udata = udata; \ +} \ + \ +static int greatest_string_equal_cb(const void *expd, const void *got, \ + void *udata) { \ + size_t *size = (size_t *)udata; \ + return (size != NULL \ + ? (0 == strncmp((const char *)expd, (const char *)got, *size)) \ + : (0 == strcmp((const char *)expd, (const char *)got))); \ +} \ + \ +static int greatest_string_printf_cb(const void *t, void *udata) { \ + (void)udata; /* note: does not check \0 termination. */ \ + return GREATEST_FPRINTF(GREATEST_STDOUT, "%s", (const char *)t); \ +} \ + \ +greatest_type_info greatest_type_info_string = { \ + greatest_string_equal_cb, greatest_string_printf_cb, \ +}; \ + \ +static int greatest_memory_equal_cb(const void *expd, const void *got, \ + void *udata) { \ + greatest_memory_cmp_env *env = (greatest_memory_cmp_env *)udata; \ + return (0 == memcmp(expd, got, env->size)); \ +} \ + \ +/* Hexdump raw memory, with differences highlighted */ \ +static int greatest_memory_printf_cb(const void *t, void *udata) { \ + greatest_memory_cmp_env *env = (greatest_memory_cmp_env *)udata; \ + const unsigned char *buf = (const unsigned char *)t; \ + unsigned char diff_mark = ' '; \ + FILE *out = GREATEST_STDOUT; \ + size_t i, line_i, line_len = 0; \ + int len = 0; /* format hexdump with differences highlighted */ \ + for (i = 0; i < env->size; i+= line_len) { \ + diff_mark = ' '; \ + line_len = env->size - i; \ + if (line_len > 16) { line_len = 16; } \ + for (line_i = i; line_i < i + line_len; line_i++) { \ + if (env->exp[line_i] != env->got[line_i]) diff_mark = 'X'; \ + } \ + len += GREATEST_FPRINTF(out, "\n%04x %c ", \ + (unsigned int)i, diff_mark); \ + for (line_i = i; line_i < i + line_len; line_i++) { \ + int m = env->exp[line_i] == env->got[line_i]; /* match? */ \ + len += GREATEST_FPRINTF(out, "%02x%c", \ + buf[line_i], m ? ' ' : '<'); \ + } \ + for (line_i = 0; line_i < 16 - line_len; line_i++) { \ + len += GREATEST_FPRINTF(out, " "); \ + } \ + GREATEST_FPRINTF(out, " "); \ + for (line_i = i; line_i < i + line_len; line_i++) { \ + unsigned char c = buf[line_i]; \ + len += GREATEST_FPRINTF(out, "%c", isprint(c) ? c : '.'); \ + } \ + } \ + len += GREATEST_FPRINTF(out, "\n"); \ + return len; \ +} \ + \ +void greatest_prng_init_first_pass(int id) { \ + greatest_info.prng[id].random_order = 1; \ + greatest_info.prng[id].count_run = 0; \ +} \ + \ +int greatest_prng_init_second_pass(int id, unsigned long seed) { \ + struct greatest_prng *p = &greatest_info.prng[id]; \ + if (p->count == 0) { return 0; } \ + p->count_ceil = p->count; \ + for (p->m = 1; p->m < p->count; p->m <<= 1) {} \ + p->state = seed & 0x1fffffff; /* only use lower 29 bits */ \ + p->a = 4LU * p->state; /* to avoid overflow when */ \ + p->a = (p->a ? p->a : 4) | 1; /* multiplied by 4 */ \ + p->c = 2147483647; /* and so p->c ((2 ** 31) - 1) is */ \ + p->initialized = 1; /* always relatively prime to p->a. */ \ + fprintf(stderr, "init_second_pass: a %lu, c %lu, state %lu\n", \ + p->a, p->c, p->state); \ + return 1; \ +} \ + \ +/* Step the pseudorandom number generator until its state reaches \ + * another test ID between 0 and the test count. \ + * This use a linear congruential pseudorandom number generator, \ + * with the power-of-two ceiling of the test count as the modulus, the \ + * masked seed as the multiplier, and a prime as the increment. For \ + * each generated value < the test count, run the corresponding test. \ + * This will visit all IDs 0 <= X < mod once before repeating, \ + * with a starting position chosen based on the initial seed. \ + * For details, see: Knuth, The Art of Computer Programming \ + * Volume. 2, section 3.2.1. */ \ +void greatest_prng_step(int id) { \ + struct greatest_prng *p = &greatest_info.prng[id]; \ + do { \ + p->state = ((p->a * p->state) + p->c) & (p->m - 1); \ + } while (p->state >= p->count_ceil); \ +} \ + \ +void GREATEST_INIT(void) { \ + /* Suppress unused function warning if features aren't used */ \ + (void)greatest_run_suite; \ + (void)greatest_parse_options; \ + (void)greatest_prng_step; \ + (void)greatest_prng_init_first_pass; \ + (void)greatest_prng_init_second_pass; \ + (void)greatest_set_test_suffix; \ + \ + memset(&greatest_info, 0, sizeof(greatest_info)); \ + greatest_info.width = GREATEST_DEFAULT_WIDTH; \ + GREATEST_SET_TIME(greatest_info.begin); \ +} \ + \ +/* Report passes, failures, skipped tests, the number of \ + * assertions, and the overall run time. */ \ +void GREATEST_PRINT_REPORT(void) { \ + if (!GREATEST_LIST_ONLY()) { \ + update_counts_and_reset_suite(); \ + GREATEST_SET_TIME(greatest_info.end); \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "\nTotal: %u test%s", \ + greatest_info.tests_run, \ + greatest_info.tests_run == 1 ? "" : "s"); \ + GREATEST_CLOCK_DIFF(greatest_info.begin, \ + greatest_info.end); \ + GREATEST_FPRINTF(GREATEST_STDOUT, ", %u assertion%s\n", \ + greatest_info.assertions, \ + greatest_info.assertions == 1 ? "" : "s"); \ + GREATEST_FPRINTF(GREATEST_STDOUT, \ + "Pass: %u, fail: %u, skip: %u.\n", \ + greatest_info.passed, \ + greatest_info.failed, greatest_info.skipped); \ + } \ +} \ + \ +greatest_type_info greatest_type_info_memory = { \ + greatest_memory_equal_cb, greatest_memory_printf_cb, \ +}; \ + \ +greatest_run_info greatest_info + +/* Handle command-line arguments, etc. */ +#define GREATEST_MAIN_BEGIN() \ + do { \ + GREATEST_INIT(); \ + greatest_parse_options(argc, argv); \ + } while (0) + +/* Report results, exit with exit status based on results. */ +#define GREATEST_MAIN_END() \ + do { \ + GREATEST_PRINT_REPORT(); \ + return (greatest_all_passed() ? EXIT_SUCCESS : EXIT_FAILURE); \ + } while (0) + +/* Make abbreviations without the GREATEST_ prefix for the + * most commonly used symbols. */ +#if GREATEST_USE_ABBREVS +#define TEST GREATEST_TEST +#define SUITE GREATEST_SUITE +#define SUITE_EXTERN GREATEST_SUITE_EXTERN +#define RUN_TEST GREATEST_RUN_TEST +#define RUN_TEST1 GREATEST_RUN_TEST1 +#define RUN_SUITE GREATEST_RUN_SUITE +#define IGNORE_TEST GREATEST_IGNORE_TEST +#define ASSERT GREATEST_ASSERT +#define ASSERTm GREATEST_ASSERTm +#define ASSERT_FALSE GREATEST_ASSERT_FALSE +#define ASSERT_EQ GREATEST_ASSERT_EQ +#define ASSERT_NEQ GREATEST_ASSERT_NEQ +#define ASSERT_GT GREATEST_ASSERT_GT +#define ASSERT_GTE GREATEST_ASSERT_GTE +#define ASSERT_LT GREATEST_ASSERT_LT +#define ASSERT_LTE GREATEST_ASSERT_LTE +#define ASSERT_EQ_FMT GREATEST_ASSERT_EQ_FMT +#define ASSERT_IN_RANGE GREATEST_ASSERT_IN_RANGE +#define ASSERT_EQUAL_T GREATEST_ASSERT_EQUAL_T +#define ASSERT_STR_EQ GREATEST_ASSERT_STR_EQ +#define ASSERT_STRN_EQ GREATEST_ASSERT_STRN_EQ +#define ASSERT_MEM_EQ GREATEST_ASSERT_MEM_EQ +#define ASSERT_ENUM_EQ GREATEST_ASSERT_ENUM_EQ +#define ASSERT_FALSEm GREATEST_ASSERT_FALSEm +#define ASSERT_EQm GREATEST_ASSERT_EQm +#define ASSERT_NEQm GREATEST_ASSERT_NEQm +#define ASSERT_GTm GREATEST_ASSERT_GTm +#define ASSERT_GTEm GREATEST_ASSERT_GTEm +#define ASSERT_LTm GREATEST_ASSERT_LTm +#define ASSERT_LTEm GREATEST_ASSERT_LTEm +#define ASSERT_EQ_FMTm GREATEST_ASSERT_EQ_FMTm +#define ASSERT_IN_RANGEm GREATEST_ASSERT_IN_RANGEm +#define ASSERT_EQUAL_Tm GREATEST_ASSERT_EQUAL_Tm +#define ASSERT_STR_EQm GREATEST_ASSERT_STR_EQm +#define ASSERT_STRN_EQm GREATEST_ASSERT_STRN_EQm +#define ASSERT_MEM_EQm GREATEST_ASSERT_MEM_EQm +#define ASSERT_ENUM_EQm GREATEST_ASSERT_ENUM_EQm +#define PASS GREATEST_PASS +#define FAIL GREATEST_FAIL +#define SKIP GREATEST_SKIP +#define PASSm GREATEST_PASSm +#define FAILm GREATEST_FAILm +#define SKIPm GREATEST_SKIPm +#define SET_SETUP GREATEST_SET_SETUP_CB +#define SET_TEARDOWN GREATEST_SET_TEARDOWN_CB +#define CHECK_CALL GREATEST_CHECK_CALL +#define SHUFFLE_TESTS GREATEST_SHUFFLE_TESTS +#define SHUFFLE_SUITES GREATEST_SHUFFLE_SUITES + +#ifdef GREATEST_VA_ARGS +#define RUN_TESTp GREATEST_RUN_TESTp +#endif + +#if GREATEST_USE_LONGJMP +#define ASSERT_OR_LONGJMP GREATEST_ASSERT_OR_LONGJMP +#define ASSERT_OR_LONGJMPm GREATEST_ASSERT_OR_LONGJMPm +#define FAIL_WITH_LONGJMP GREATEST_FAIL_WITH_LONGJMP +#define FAIL_WITH_LONGJMPm GREATEST_FAIL_WITH_LONGJMPm +#endif + +#endif /* USE_ABBREVS */ + +#if defined(__cplusplus) && !defined(GREATEST_NO_EXTERN_CPLUSPLUS) +} +#endif + +#endif \ No newline at end of file diff --git a/framework/3rd/minicoro/minicoro.h b/framework/3rd/minicoro/minicoro.h new file mode 100644 index 0000000..9403a2d --- /dev/null +++ b/framework/3rd/minicoro/minicoro.h @@ -0,0 +1,1789 @@ +/* +Minimal asymmetric stackful cross-platform coroutine library in pure C. +minicoro - v0.1.2 - 13/Feb/2021 +Eduardo Bart - edub4rt@gmail.com +https://github.com/edubart/minicoro + +Minicoro is single file library for using asymmetric coroutines in C. +The API is inspired by Lua coroutines but with C use in mind. + +# Features + +- Stackful asymmetric coroutines. +- Supports nesting coroutines (resuming a coroutine from another coroutine). +- Supports custom allocators. +- Storage system to allow passing values between yield and resume. +- Customizable stack size. +- Coroutine API design inspired by Lua with C use in mind. +- Yield across any C function. +- Made to work in multithread applications. +- Cross platform. +- Minimal, self contained and no external dependencies. +- Readable sources and documented. +- Implemented via assembly, ucontext or fibers. +- Lightweight and very efficient. +- Works in most C89 compilers. +- Error prone API, returning proper error codes on misuse. +- Support running with Valgrind, ASan (AddressSanitizer) and TSan (ThreadSanitizer). + +# Supported Platforms + +Most platforms are supported through different methods: + +| Platform | Assembly Method | Fallback Method | +|--------------|------------------|-------------------| +| Android | ARM/ARM64 | N/A | +| Windows | x86_64 | Windows fibers | +| Linux | x86_64/i686 | ucontext | +| Mac OS X | x86_64 | ucontext | +| Browser | N/A | Emscripten fibers | +| Raspberry Pi | ARM | ucontext | +| RISC-V | rv64/rv32 | ucontext | + +The assembly method is used by default if supported by the compiler and CPU, +otherwise ucontext or fiber method is used as a fallback. + +The assembly method is very efficient, it just take a few cycles +to create, resume, yield or destroy a coroutine. + +# Caveats + +- Don't use coroutines with C++ exceptions, this is not supported. +- When using C++ RAII (i.e. destructors) you must resume the coroutine until it dies to properly execute all destructors. +- To use in multithread applications, you must compile with C compiler that supports `thread_local` qualifier. +- Some unsupported sanitizers for C may trigger false warnings when using coroutines. +- The `mco_coro` object is not thread safe, you should lock each coroutine into a thread. +- Stack space is fixed, it cannot grow. By default it has about 56KB of space, this can be changed on coroutine creation. +- Take care to not cause stack overflows (run out of stack space), otherwise your program may crash or not, the behavior is undefined. +- On WebAssembly you must compile with emscripten flag `-s ASYNCIFY=1`. + +# Introduction + +A coroutine represents an independent "green" thread of execution. +Unlike threads in multithread systems, however, +a coroutine only suspends its execution by explicitly calling a yield function. + +You create a coroutine by calling `mco_create`. +Its sole argument is a `mco_desc` structure with a description for the coroutine. +The `mco_create` function only creates a new coroutine and returns a handle to it, it does not start the coroutine. + +You execute a coroutine by calling `mco_resume`. +When calling a resume function the coroutine starts its execution by calling its body function. +After the coroutine starts running, it runs until it terminates or yields. + +A coroutine yields by calling `mco_yield`. +When a coroutine yields, the corresponding resume returns immediately, +even if the yield happens inside nested function calls (that is, not in the main function). +The next time you resume the same coroutine, it continues its execution from the point where it yielded. + +To associate a persistent value with the coroutine, +you can optionally set `user_data` on its creation and later retrieve with `mco_get_user_data`. + +To pass values between resume and yield, +you can optionally use `mco_push` and `mco_pop` APIs, +they are intended to pass temporary values using a LIFO style buffer. +The storage system can also be used to send and receive initial values on coroutine creation or before it finishes. + +# Usage + +To use minicoro, do the following in one .c file: + +```c +#define MINICORO_IMPL +#include "minicoro.h" +``` + +You can do `#include "minicoro.h"` in other parts of the program just like any other header. + +## Minimal Example + +The following simple example demonstrates on how to use the library: + +```c +#define MINICORO_IMPL +#include "minicoro.h" +#include + +// Coroutine entry function. +void coro_entry(mco_coro* co) { + printf("coroutine 1\n"); + mco_yield(co); + printf("coroutine 2\n"); +} + +int main() { + // First initialize a `desc` object through `mco_desc_init`. + mco_desc desc = mco_desc_init(coro_entry, 0); + // Configure `desc` fields when needed (e.g. customize user_data or allocation functions). + desc.user_data = NULL; + // Call `mco_create` with the output coroutine pointer and `desc` pointer. + mco_coro* co; + mco_result res = mco_create(&co, &desc); + assert(res == MCO_SUCCESS); + // The coroutine should be now in suspended state. + assert(mco_status(co) == MCO_SUSPENDED); + // Call `mco_resume` to start for the first time, switching to its context. + res = mco_resume(co); // Should print "coroutine 1". + assert(res == MCO_SUCCESS); + // We get back from coroutine context in suspended state (because it's unfinished). + assert(mco_status(co) == MCO_SUSPENDED); + // Call `mco_resume` to resume for a second time. + res = mco_resume(co); // Should print "coroutine 2". + assert(res == MCO_SUCCESS); + // The coroutine finished and should be now dead. + assert(mco_status(co) == MCO_DEAD); + // Call `mco_destroy` to destroy the coroutine. + res = mco_destroy(co); + assert(res == MCO_SUCCESS); + return 0; +} +``` + +_NOTE_: In case you don't want to use the minicoro allocator system you should +allocate a coroutine object yourself using `mco_desc.coro_size` and call `mco_init`, +then later to destroy call `mco_deinit` and deallocate it. + +## Yielding from anywhere + +You can yield the current running coroutine from anywhere +without having to pass `mco_coro` pointers around, +to this just use `mco_yield(mco_running())`. + +## Passing data between yield and resume + +The library has the storage interface to assist passing data between yield and resume. +It's usage is straightforward, +use `mco_push` to send data before a `mco_resume` or `mco_yield`, +then later use `mco_pop` after a `mco_resume` or `mco_yield` to receive data. +Take care to not mismatch a push and pop, otherwise these functions will return +an error. + +## Error handling + +The library return error codes in most of its API in case of misuse or system error, +the user is encouraged to handle them properly. + +## Library customization + +The following can be defined to change the library behavior: + +- `MCO_API` - Public API qualifier. Default is `extern`. +- `MCO_MIN_STACK_SIZE` - Minimum stack size when creating a coroutine. Default is 32768. +- `MCO_DEFAULT_STORAGE_SIZE` - Size of coroutine storage buffer. Default is 1024. +- `MCO_DEFAULT_STACK_SIZE` - Default stack size when creating a coroutine. Default is 57344. +- `MCO_MALLOC` - Default allocation function. Default is `malloc`. +- `MCO_FREE` - Default deallocation function. Default is `free`. +- `MCO_DEBUG` - Enable debug mode, logging any runtime error to stdout. Defined automatically unless `NDEBUG` or `MCO_NO_DEBUG` is defined. +- `MCO_NO_DEBUG` - Disable debug mode. +- `MCO_NO_MULTITHREAD` - Disable multithread usage. Multithread is supported when `thread_local` is supported. +- `MCO_NO_DEFAULT_ALLOCATORS` - Disable the default allocator using `MCO_MALLOC` and `MCO_FREE`. +- `MCO_ZERO_MEMORY` - Zero memory of stack for new coroutines and when poping storage, intended for garbage collected environments. +- `MCO_USE_ASM` - Force use of assembly context switch implementation. +- `MCO_USE_UCONTEXT` - Force use of ucontext context switch implementation. +- `MCO_USE_FIBERS` - Force use of fibers context switch implementation. +- `MCO_USE_VALGRIND` - Define if you want run with valgrind to fix accessing memory errors. + +# License + +Your choice of either Public Domain or MIT No Attribution, see end of file. +*/ + + +#ifndef MINICORO_H +#define MINICORO_H + +#ifdef __cplusplus +extern "C" { +#endif + +/* Public API qualifier. */ +#ifndef MCO_API +#define MCO_API extern +#endif + +/* Size of coroutine storage buffer. */ +#ifndef MCO_DEFAULT_STORAGE_SIZE +#define MCO_DEFAULT_STORAGE_SIZE 1024 +#endif + +#include /* for size_t */ + +/* ---------------------------------------------------------------------------------------------- */ + +/* Coroutine states. */ +typedef enum mco_state { + MCO_DEAD = 0, /* The coroutine has finished normally or was uninitialized before finishing. */ + MCO_NORMAL, /* The coroutine is active but not running (that is, it has resumed another coroutine). */ + MCO_RUNNING, /* The coroutine is active and running. */ + MCO_SUSPENDED, /* The coroutine is suspended (in a call to yield, or it has not started running yet). */ +} mco_state; + +/* Coroutine result codes. */ +typedef enum mco_result { + MCO_SUCCESS = 0, + MCO_GENERIC_ERROR, + MCO_INVALID_POINTER, + MCO_INVALID_COROUTINE, + MCO_NOT_SUSPENDED, + MCO_NOT_RUNNING, + MCO_MAKE_CONTEXT_ERROR, + MCO_SWITCH_CONTEXT_ERROR, + MCO_NOT_ENOUGH_SPACE, + MCO_OUT_OF_MEMORY, + MCO_INVALID_ARGUMENTS, + MCO_INVALID_OPERATION, +} mco_result; + +/* Coroutine structure. */ +typedef struct mco_coro mco_coro; +struct mco_coro { + void* context; + mco_state state; + void (*func)(mco_coro* co); + mco_coro* prev_co; + void* user_data; + void* allocator_data; + void (*free_cb)(void* ptr, void* allocator_data); + void* stack_base; /* Stack base address, can be used to scan memory in a garbage collector. */ + size_t stack_size; + unsigned char* storage; + size_t bytes_stored; + size_t storage_size; + void* asan_prev_stack; /* Used by address sanitizer. */ + void* tsan_prev_fiber; /* Used by thread sanitizer. */ + void* tsan_fiber; /* Used by thread sanitizer. */ +}; + +/* Structure used to initialize a coroutine. */ +typedef struct mco_desc { + void (*func)(mco_coro* co); /* Entry point function for the coroutine. */ + void* user_data; /* Coroutine user data, can be get with `mco_get_user_data`. */ + /* Custom allocation interface. */ + void* (*malloc_cb)(size_t size, void* allocator_data); /* Custom allocation function. */ + void (*free_cb)(void* ptr, void* allocator_data); /* Custom deallocation function. */ + void* allocator_data; /* User data pointer passed to `malloc`/`free` allocation functions. */ + size_t storage_size; /* Coroutine storage size, to be used with the storage APIs. */ + /* These must be initialized only through `mco_init_desc`. */ + size_t coro_size; /* Coroutine structure size. */ + size_t stack_size; /* Coroutine stack size. */ +} mco_desc; + +/* Coroutine functions. */ +MCO_API mco_desc mco_desc_init(void (*func)(mco_coro* co), size_t stack_size); /* Initialize description of a coroutine. When stack size is 0 then MCO_DEFAULT_STACK_SIZE is used. */ +MCO_API mco_result mco_init(mco_coro* co, mco_desc* desc); /* Initialize the coroutine. */ +MCO_API mco_result mco_uninit(mco_coro* co); /* Uninitialize the coroutine, may fail if it's not dead or suspended. */ +MCO_API mco_result mco_create(mco_coro** out_co, mco_desc* desc); /* Allocates and initializes a new coroutine. */ +MCO_API mco_result mco_destroy(mco_coro* co); /* Uninitialize and deallocate the coroutine, may fail if it's not dead or suspended. */ +MCO_API mco_result mco_resume(mco_coro* co); /* Starts or continues the execution of the coroutine. */ +MCO_API mco_result mco_yield(mco_coro* co); /* Suspends the execution of a coroutine. */ +MCO_API mco_state mco_status(mco_coro* co); /* Returns the status of the coroutine. */ +MCO_API void* mco_get_user_data(mco_coro* co); /* Get coroutine user data supplied on coroutine creation. */ + +/* Storage interface functions, used to pass values between yield and resume. */ +MCO_API mco_result mco_push(mco_coro* co, const void* src, size_t len); /* Push bytes to the coroutine storage. Use to send values between yield and resume. */ +MCO_API mco_result mco_pop(mco_coro* co, void* dest, size_t len); /* Pop bytes from the coroutine storage. Use to get values between yield and resume. */ +MCO_API mco_result mco_peek(mco_coro* co, void* dest, size_t len); /* Like `mco_pop` but it does not consumes the storage. */ +MCO_API size_t mco_get_bytes_stored(mco_coro* co); /* Get the available bytes that can be retrieved with a `mco_pop`. */ +MCO_API size_t mco_get_storage_size(mco_coro* co); /* Get the total storage size. */ + +/* Misc functions. */ +MCO_API mco_coro* mco_running(void); /* Returns the running coroutine for the current thread. */ +MCO_API const char* mco_result_description(mco_result res); /* Get the description of a result. */ + +#ifdef __cplusplus +} +#endif + +#endif /* MINICORO_H */ + +#ifdef MINICORO_IMPL + +#ifdef __cplusplus +extern "C" { +#endif + +/* ---------------------------------------------------------------------------------------------- */ + +/* Minimum stack size when creating a coroutine. */ +#ifndef MCO_MIN_STACK_SIZE +#define MCO_MIN_STACK_SIZE 32768 +#endif + +/* Default stack size when creating a coroutine. */ +#ifndef MCO_DEFAULT_STACK_SIZE +#define MCO_DEFAULT_STACK_SIZE 57344 /* Don't use multiples of 64K to avoid D-cache aliasing conflicts. */ +#endif + +/* Detect implementation based on OS, arch and compiler. */ +#if !defined(MCO_USE_UCONTEXT) && !defined(MCO_USE_FIBERS) && !defined(MCO_USE_ASM) + #if defined(_WIN32) + #if (defined(__GNUC__) && defined(__x86_64__)) || (defined(_MSC_VER) && defined(_M_X64)) + #define MCO_USE_ASM + #else + #define MCO_USE_FIBERS + #endif + #elif defined(__EMSCRIPTEN__) + #define MCO_USE_FIBERS + #else + #if __GNUC__ >= 3 /* Assembly extension supported. */ + #if defined(__x86_64__) || \ + defined(__i386) || defined(__i386__) || \ + defined(__ARM_EABI__) || defined(__aarch64__) || \ + defined(__riscv) + #define MCO_USE_ASM + #else + #define MCO_USE_UCONTEXT + #endif + #else + #define MCO_USE_UCONTEXT + #endif + #endif +#endif + +#define _MCO_UNUSED(x) (void)(x) + +#if !defined(MCO_NO_DEBUG) && !defined(NDEBUG) && !defined(MCO_DEBUG) +#define MCO_DEBUG +#endif + +#ifndef MCO_LOG + #ifdef MCO_DEBUG + #include + #define MCO_LOG(s) puts(s) + #else + #define MCO_LOG(s) + #endif +#endif + +#ifndef MCO_ASSERT + #ifdef MCO_DEBUG + #include + #define MCO_ASSERT(c) assert(c) + #else + #define MCO_ASSERT(c) + #endif +#endif + +#ifndef MCO_THREAD_LOCAL + #ifdef MCO_NO_MULTITHREAD + #define MCO_THREAD_LOCAL + #else + #ifdef thread_local + #define MCO_THREAD_LOCAL thread_local + #elif __STDC_VERSION__ >= 201112 && !defined(__STDC_NO_THREADS__) + #define MCO_THREAD_LOCAL _Thread_local + #elif defined(_WIN32) && (defined(_MSC_VER) || defined(__ICL) || defined(__DMC__) || defined(__BORLANDC__)) + #define MCO_THREAD_LOCAL __declspec(thread) + #elif defined(__GNUC__) || defined(__SUNPRO_C) || defined(__xlC__) + #define MCO_THREAD_LOCAL __thread + #else /* No thread local support, `mco_running` will be thread unsafe. */ + #define MCO_THREAD_LOCAL + #define MCO_NO_MULTITHREAD + #endif + #endif +#endif + +#ifndef MCO_FORCE_INLINE + #ifdef _MSC_VER + #define MCO_FORCE_INLINE __forceinline + #elif defined(__GNUC__) + #if defined(__STRICT_ANSI__) + #define MCO_FORCE_INLINE __inline__ __attribute__((always_inline)) + #else + #define MCO_FORCE_INLINE inline __attribute__((always_inline)) + #endif + #elif defined(__BORLANDC__) || defined(__DMC__) || defined(__SC__) || defined(__WATCOMC__) || defined(__LCC__) || defined(__DECC) + #define MCO_FORCE_INLINE __inline + #else /* No inline support. */ + #define MCO_FORCE_INLINE + #endif +#endif + +#ifndef MCO_NO_DEFAULT_ALLOCATORS +#ifndef MCO_MALLOC + #include + #define MCO_MALLOC malloc + #define MCO_FREE free +#endif +static void* mco_malloc(size_t size, void* allocator_data) { + _MCO_UNUSED(allocator_data); + return MCO_MALLOC(size); +} +static void mco_free(void* ptr, void* allocator_data) { + _MCO_UNUSED(allocator_data); + MCO_FREE(ptr); +} +#endif /* MCO_NO_DEFAULT_ALLOCATORS */ + +#if defined(__has_feature) + #if __has_feature(address_sanitizer) + #define _MCO_USE_ASAN + #endif + #if __has_feature(thread_sanitizer) + #define _MCO_USE_TSAN + #endif +#endif +#if defined(__SANITIZE_ADDRESS__) + #define _MCO_USE_ASAN +#endif +#if defined(__SANITIZE_THREAD__) + #define _MCO_USE_TSAN +#endif +#ifdef _MCO_USE_ASAN +void __sanitizer_start_switch_fiber(void** fake_stack_save, const void *bottom, size_t size); +void __sanitizer_finish_switch_fiber(void* fake_stack_save, const void **bottom_old, size_t *size_old); +#endif +#ifdef _MCO_USE_TSAN +void* __tsan_get_current_fiber(void); +void* __tsan_create_fiber(unsigned flags); +void __tsan_destroy_fiber(void* fiber); +void __tsan_switch_to_fiber(void* fiber, unsigned flags); +#endif + +#include /* For memcpy and memset. */ + +/* Utility for aligning addresses. */ +static MCO_FORCE_INLINE size_t _mco_align_forward(size_t addr, size_t align) { + return (addr + (align-1)) & ~(align-1); +} + +/* Variable holding the current running coroutine per thread. */ +static MCO_THREAD_LOCAL mco_coro* mco_current_co = NULL; + +static MCO_FORCE_INLINE void _mco_prepare_jumpin(mco_coro* co) { + /* Set the old coroutine to normal state and update it. */ + mco_coro* prev_co = mco_running(); /* Must access through `mco_running`. */ + MCO_ASSERT(co->prev_co == NULL); + co->prev_co = prev_co; + if(prev_co) { + MCO_ASSERT(prev_co->state == MCO_RUNNING); + prev_co->state = MCO_NORMAL; + } + mco_current_co = co; +#ifdef _MCO_USE_ASAN + if(prev_co) { + void* bottom_old = NULL; + size_t size_old = 0; + __sanitizer_finish_switch_fiber(prev_co->asan_prev_stack, (const void**)&bottom_old, &size_old); + prev_co->asan_prev_stack = NULL; + } + __sanitizer_start_switch_fiber(&co->asan_prev_stack, co->stack_base, co->stack_size); +#endif +#ifdef _MCO_USE_TSAN + co->tsan_prev_fiber = __tsan_get_current_fiber(); + __tsan_switch_to_fiber(co->tsan_fiber, 0); +#endif +} + +static MCO_FORCE_INLINE void _mco_prepare_jumpout(mco_coro* co) { + /* Switch back to the previous running coroutine. */ + MCO_ASSERT(mco_running() == co); + mco_coro* prev_co = co->prev_co; + co->prev_co = NULL; + if(prev_co) { + MCO_ASSERT(prev_co->state == MCO_NORMAL); + prev_co->state = MCO_RUNNING; + } + mco_current_co = prev_co; +#ifdef _MCO_USE_ASAN + void* bottom_old = NULL; + size_t size_old = 0; + __sanitizer_finish_switch_fiber(co->asan_prev_stack, (const void**)&bottom_old, &size_old); + co->asan_prev_stack = NULL; + if(prev_co) { + __sanitizer_start_switch_fiber(&prev_co->asan_prev_stack, bottom_old, size_old); + } +#endif +#ifdef _MCO_USE_TSAN + void* tsan_prev_fiber = co->tsan_prev_fiber; + co->tsan_prev_fiber = NULL; + __tsan_switch_to_fiber(tsan_prev_fiber, 0); +#endif +} + +static void _mco_jumpin(mco_coro* co); +static void _mco_jumpout(mco_coro* co); + +static void _mco_main(mco_coro* co) { + co->func(co); /* Run the coroutine function. */ + co->state = MCO_DEAD; /* Coroutine finished successfully, set state to dead. */ + _mco_jumpout(co); /* Jump back to the old context .*/ +} + +/* ---------------------------------------------------------------------------------------------- */ + +#if defined(MCO_USE_UCONTEXT) || defined(MCO_USE_ASM) + +/* +Some of the following assembly code is taken from LuaCoco by Mike Pall. +See https://coco.luajit.org/index.html + +MIT license + +Copyright (C) 2004-2016 Mike Pall. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +#ifdef MCO_USE_ASM + +#if defined(__x86_64__) || defined(_M_X64) + +#ifdef _WIN32 + +typedef struct _mco_ctxbuf { + void *rip, *rsp, *rbp, *rbx, *r12, *r13, *r14, *r15, *rdi, *rsi; + void* xmm[20]; /* xmm6, xmm7, xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15 */ + void* fiber_storage; + void* dealloc_stack; + void* stack_limit; + void* stack_base; +} _mco_ctxbuf; + +#ifdef __GNUC__ +#define _MCO_ASM_BLOB __attribute__((section(".text#"))) +#elif defined(_MSC_VER) +#define _MCO_ASM_BLOB __declspec(allocate(".text")) +#pragma section(".text") +#endif + +_MCO_ASM_BLOB static unsigned char _mco_wrap_main_code[] = { + 0x4c, 0x89, 0xe9, /* mov %r13,%rcx */ + 0x41, 0xff, 0xe4, /* jmpq *%r12 */ + 0xc3, /* retq */ + 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90 /* nop */ +}; + +_MCO_ASM_BLOB static unsigned char _mco_switch_code[] = { + 0x48, 0x8d, 0x05, 0x52, 0x01, 0x00, 0x00, /* lea 0x152(%rip),%rax */ + 0x48, 0x89, 0x01, /* mov %rax,(%rcx) */ + 0x48, 0x89, 0x61, 0x08, /* mov %rsp,0x8(%rcx) */ + 0x48, 0x89, 0x69, 0x10, /* mov %rbp,0x10(%rcx) */ + 0x48, 0x89, 0x59, 0x18, /* mov %rbx,0x18(%rcx) */ + 0x4c, 0x89, 0x61, 0x20, /* mov %r12,0x20(%rcx) */ + 0x4c, 0x89, 0x69, 0x28, /* mov %r13,0x28(%rcx) */ + 0x4c, 0x89, 0x71, 0x30, /* mov %r14,0x30(%rcx) */ + 0x4c, 0x89, 0x79, 0x38, /* mov %r15,0x38(%rcx) */ + 0x48, 0x89, 0x79, 0x40, /* mov %rdi,0x40(%rcx) */ + 0x48, 0x89, 0x71, 0x48, /* mov %rsi,0x48(%rcx) */ + 0x66, 0x0f, 0xd6, 0x71, 0x50, /* movq %xmm6,0x50(%rcx) */ + 0x66, 0x0f, 0xd6, 0x79, 0x60, /* movq %xmm7,0x60(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0x41, 0x70, /* movq %xmm8,0x70(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0x89, 0x80, 0x00, 0x00, 0x00, /* movq %xmm9,0x80(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0x91, 0x90, 0x00, 0x00, 0x00, /* movq %xmm10,0x90(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0x99, 0xa0, 0x00, 0x00, 0x00, /* movq %xmm11,0xa0(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0xa1, 0xb0, 0x00, 0x00, 0x00, /* movq %xmm12,0xb0(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0xa9, 0xc0, 0x00, 0x00, 0x00, /* movq %xmm13,0xc0(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0xb1, 0xd0, 0x00, 0x00, 0x00, /* movq %xmm14,0xd0(%rcx) */ + 0x66, 0x44, 0x0f, 0xd6, 0xb9, 0xe0, 0x00, 0x00, 0x00, /* movq %xmm15,0xe0(%rcx) */ + 0x65, 0x4c, 0x8b, 0x14, 0x25, 0x30, 0x00, 0x00, 0x00, /* mov %gs:0x30,%r10 */ + 0x49, 0x8b, 0x42, 0x20, /* mov 0x20(%r10),%rax */ + 0x48, 0x89, 0x81, 0xf0, 0x00, 0x00, 0x00, /* mov %rax,0xf0(%rcx) */ + 0x49, 0x8b, 0x82, 0x78, 0x14, 0x00, 0x00, /* mov 0x1478(%r10),%rax */ + 0x48, 0x89, 0x81, 0xf8, 0x00, 0x00, 0x00, /* mov %rax,0xf8(%rcx) */ + 0x49, 0x8b, 0x42, 0x10, /* mov 0x10(%r10),%rax */ + 0x48, 0x89, 0x81, 0x00, 0x01, 0x00, 0x00, /* mov %rax,0x100(%rcx) */ + 0x49, 0x8b, 0x42, 0x08, /* mov 0x8(%r10),%rax */ + 0x48, 0x89, 0x81, 0x08, 0x01, 0x00, 0x00, /* mov %rax,0x108(%rcx) */ + 0x48, 0x8b, 0x82, 0x08, 0x01, 0x00, 0x00, /* mov 0x108(%rdx),%rax */ + 0x49, 0x89, 0x42, 0x08, /* mov %rax,0x8(%r10) */ + 0x48, 0x8b, 0x82, 0x00, 0x01, 0x00, 0x00, /* mov 0x100(%rdx),%rax */ + 0x49, 0x89, 0x42, 0x10, /* mov %rax,0x10(%r10) */ + 0x48, 0x8b, 0x82, 0xf8, 0x00, 0x00, 0x00, /* mov 0xf8(%rdx),%rax */ + 0x49, 0x89, 0x82, 0x78, 0x14, 0x00, 0x00, /* mov %rax,0x1478(%r10) */ + 0x48, 0x8b, 0x82, 0xf0, 0x00, 0x00, 0x00, /* mov 0xf0(%rdx),%rax */ + 0x49, 0x89, 0x42, 0x20, /* mov %rax,0x20(%r10) */ + 0xf3, 0x44, 0x0f, 0x7e, 0xba, 0xe0, 0x00, 0x00, 0x00, /* movq 0xe0(%rdx),%xmm15 */ + 0xf3, 0x44, 0x0f, 0x7e, 0xb2, 0xd0, 0x00, 0x00, 0x00, /* movq 0xd0(%rdx),%xmm14 */ + 0xf3, 0x44, 0x0f, 0x7e, 0xaa, 0xc0, 0x00, 0x00, 0x00, /* movq 0xc0(%rdx),%xmm13 */ + 0xf3, 0x44, 0x0f, 0x7e, 0xa2, 0xb0, 0x00, 0x00, 0x00, /* movq 0xb0(%rdx),%xmm12 */ + 0xf3, 0x44, 0x0f, 0x7e, 0x9a, 0xa0, 0x00, 0x00, 0x00, /* movq 0xa0(%rdx),%xmm11 */ + 0xf3, 0x44, 0x0f, 0x7e, 0x92, 0x90, 0x00, 0x00, 0x00, /* movq 0x90(%rdx),%xmm10 */ + 0xf3, 0x44, 0x0f, 0x7e, 0x8a, 0x80, 0x00, 0x00, 0x00, /* movq 0x80(%rdx),%xmm9 */ + 0xf3, 0x44, 0x0f, 0x7e, 0x42, 0x70, /* movq 0x70(%rdx),%xmm8 */ + 0xf3, 0x0f, 0x7e, 0x7a, 0x60, /* movq 0x60(%rdx),%xmm7 */ + 0xf3, 0x0f, 0x7e, 0x72, 0x50, /* movq 0x50(%rdx),%xmm6 */ + 0x48, 0x8b, 0x72, 0x48, /* mov 0x48(%rdx),%rsi */ + 0x48, 0x8b, 0x7a, 0x40, /* mov 0x40(%rdx),%rdi */ + 0x4c, 0x8b, 0x7a, 0x38, /* mov 0x38(%rdx),%r15 */ + 0x4c, 0x8b, 0x72, 0x30, /* mov 0x30(%rdx),%r14 */ + 0x4c, 0x8b, 0x6a, 0x28, /* mov 0x28(%rdx),%r13 */ + 0x4c, 0x8b, 0x62, 0x20, /* mov 0x20(%rdx),%r12 */ + 0x48, 0x8b, 0x5a, 0x18, /* mov 0x18(%rdx),%rbx */ + 0x48, 0x8b, 0x6a, 0x10, /* mov 0x10(%rdx),%rbp */ + 0x48, 0x8b, 0x62, 0x08, /* mov 0x8(%rdx),%rsp */ + 0xff, 0x22, /* jmpq *(%rdx) */ + 0xc3, /* retq */ + 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, /* nop */ +}; + +void (*_mco_wrap_main)(void) = (void(*)(void))(void*)_mco_wrap_main_code; +void (*_mco_switch)(_mco_ctxbuf* from, _mco_ctxbuf* to) = (void(*)(_mco_ctxbuf* from, _mco_ctxbuf* to))(void*)_mco_switch_code; + +static mco_result _mco_makectx(mco_coro* co, _mco_ctxbuf* ctx, void* stack_base, size_t stack_size) { + stack_size = stack_size - 32; /* Reserve 32 bytes for the shadow space. */ + void** stack_high_ptr = (void**)((size_t)stack_base + stack_size - sizeof(size_t)); + stack_high_ptr[0] = (void*)(0xdeaddeaddeaddead); /* Dummy return address. */ + ctx->rip = (void*)(_mco_wrap_main); + ctx->rsp = (void*)(stack_high_ptr); + ctx->r12 = (void*)(_mco_main); + ctx->r13 = (void*)(co); + void* stack_top = (void*)((size_t)stack_base + stack_size); + ctx->stack_base = stack_top; + ctx->stack_limit = stack_base; + ctx->dealloc_stack = stack_base; + return MCO_SUCCESS; +} + +#else /* not _WIN32 */ + +typedef struct _mco_ctxbuf { + void *rip, *rsp, *rbp, *rbx, *r12, *r13, *r14, *r15; +} _mco_ctxbuf; + +void _mco_wrap_main(void); +int _mco_switch(_mco_ctxbuf* from, _mco_ctxbuf* to); + +__asm__( + ".text\n" +#ifdef __MACH__ /* Mac OS X assembler */ + ".globl __mco_wrap_main\n" + "__mco_wrap_main:\n" +#else /* Linux assembler */ + ".globl _mco_wrap_main\n" + ".type _mco_wrap_main @function\n" + ".hidden _mco_wrap_main\n" + "_mco_wrap_main:\n" +#endif + " movq %r13, %rdi\n" + " jmpq *%r12\n" +#ifndef __MACH__ + ".size _mco_wrap_main, .-_mco_wrap_main\n" +#endif +); + +__asm__( + ".text\n" +#ifdef __MACH__ /* Mac OS assembler */ + ".globl __mco_switch\n" + "__mco_switch:\n" +#else /* Linux assembler */ + ".globl _mco_switch\n" + ".type _mco_switch @function\n" + ".hidden _mco_switch\n" + "_mco_switch:\n" +#endif + " leaq 0x3d(%rip), %rax\n" + " movq %rax, (%rdi)\n" + " movq %rsp, 8(%rdi)\n" + " movq %rbp, 16(%rdi)\n" + " movq %rbx, 24(%rdi)\n" + " movq %r12, 32(%rdi)\n" + " movq %r13, 40(%rdi)\n" + " movq %r14, 48(%rdi)\n" + " movq %r15, 56(%rdi)\n" + " movq 56(%rsi), %r15\n" + " movq 48(%rsi), %r14\n" + " movq 40(%rsi), %r13\n" + " movq 32(%rsi), %r12\n" + " movq 24(%rsi), %rbx\n" + " movq 16(%rsi), %rbp\n" + " movq 8(%rsi), %rsp\n" + " jmpq *(%rsi)\n" + " ret\n" +#ifndef __MACH__ + ".size _mco_switch, .-_mco_switch\n" +#endif +); + +static mco_result _mco_makectx(mco_coro* co, _mco_ctxbuf* ctx, void* stack_base, size_t stack_size) { + stack_size = stack_size - 128; /* Reserve 128 bytes for the Red Zone space (System V AMD64 ABI). */ + void** stack_high_ptr = (void**)((size_t)stack_base + stack_size - sizeof(size_t)); + stack_high_ptr[0] = (void*)(0xdeaddeaddeaddead); /* Dummy return address. */ + ctx->rip = (void*)(_mco_wrap_main); + ctx->rsp = (void*)(stack_high_ptr); + ctx->r12 = (void*)(_mco_main); + ctx->r13 = (void*)(co); + return MCO_SUCCESS; +} + +#endif /* not _WIN32 */ + +#elif defined(__riscv) + +typedef struct _mco_ctxbuf { + void* s[12]; /* s0-s11 */ + void* ra; + void* pc; + void* sp; +#ifdef __riscv_flen +#if __riscv_flen == 64 + double fs[12]; /* fs0-fs11 */ +#elif __riscv_flen == 32 + float fs[12]; /* fs0-fs11 */ +#endif +#endif /* __riscv_flen */ +} _mco_ctxbuf; + +void _mco_wrap_main(void); +int _mco_switch(_mco_ctxbuf* from, _mco_ctxbuf* to); + +__asm__( + ".text\n" + ".globl _mco_wrap_main\n" + ".type _mco_wrap_main @function\n" + ".hidden _mco_wrap_main\n" + "_mco_wrap_main:\n" + " mv a0, s0\n" + " jr s1\n" + ".size _mco_wrap_main, .-_mco_wrap_main\n" +); + +__asm__( + ".text\n" + ".globl _mco_switch\n" + ".type _mco_switch @function\n" + ".hidden _mco_switch\n" + "_mco_switch:\n" + #if __riscv_xlen == 64 + " sd s0, 0x00(a0)\n" + " sd s1, 0x08(a0)\n" + " sd s2, 0x10(a0)\n" + " sd s3, 0x18(a0)\n" + " sd s4, 0x20(a0)\n" + " sd s5, 0x28(a0)\n" + " sd s6, 0x30(a0)\n" + " sd s7, 0x38(a0)\n" + " sd s8, 0x40(a0)\n" + " sd s9, 0x48(a0)\n" + " sd s10, 0x50(a0)\n" + " sd s11, 0x58(a0)\n" + " sd ra, 0x60(a0)\n" + " sd ra, 0x68(a0)\n" /* pc */ + " sd sp, 0x70(a0)\n" + #ifdef __riscv_flen + #if __riscv_flen == 64 + " fsd fs0, 0x78(a0)\n" + " fsd fs1, 0x80(a0)\n" + " fsd fs2, 0x88(a0)\n" + " fsd fs3, 0x90(a0)\n" + " fsd fs4, 0x98(a0)\n" + " fsd fs5, 0xa0(a0)\n" + " fsd fs6, 0xa8(a0)\n" + " fsd fs7, 0xb0(a0)\n" + " fsd fs8, 0xb8(a0)\n" + " fsd fs9, 0xc0(a0)\n" + " fsd fs10, 0xc8(a0)\n" + " fsd fs11, 0xd0(a0)\n" + " fld fs0, 0x78(a1)\n" + " fld fs1, 0x80(a1)\n" + " fld fs2, 0x88(a1)\n" + " fld fs3, 0x90(a1)\n" + " fld fs4, 0x98(a1)\n" + " fld fs5, 0xa0(a1)\n" + " fld fs6, 0xa8(a1)\n" + " fld fs7, 0xb0(a1)\n" + " fld fs8, 0xb8(a1)\n" + " fld fs9, 0xc0(a1)\n" + " fld fs10, 0xc8(a1)\n" + " fld fs11, 0xd0(a1)\n" + #else + #error "Unsupported RISC-V FLEN" + #endif + #endif /* __riscv_flen */ + " ld s0, 0x00(a1)\n" + " ld s1, 0x08(a1)\n" + " ld s2, 0x10(a1)\n" + " ld s3, 0x18(a1)\n" + " ld s4, 0x20(a1)\n" + " ld s5, 0x28(a1)\n" + " ld s6, 0x30(a1)\n" + " ld s7, 0x38(a1)\n" + " ld s8, 0x40(a1)\n" + " ld s9, 0x48(a1)\n" + " ld s10, 0x50(a1)\n" + " ld s11, 0x58(a1)\n" + " ld ra, 0x60(a1)\n" + " ld a2, 0x68(a1)\n" /* pc */ + " ld sp, 0x70(a1)\n" + " jr a2\n" + #elif __riscv_xlen == 32 + " sw s0, 0x00(a0)\n" + " sw s1, 0x04(a0)\n" + " sw s2, 0x08(a0)\n" + " sw s3, 0x0c(a0)\n" + " sw s4, 0x10(a0)\n" + " sw s5, 0x14(a0)\n" + " sw s6, 0x18(a0)\n" + " sw s7, 0x1c(a0)\n" + " sw s8, 0x20(a0)\n" + " sw s9, 0x24(a0)\n" + " sw s10, 0x28(a0)\n" + " sw s11, 0x2c(a0)\n" + " sw ra, 0x30(a0)\n" + " sw ra, 0x34(a0)\n" /* pc */ + " sw sp, 0x38(a0)\n" + #ifdef __riscv_flen + #if __riscv_flen == 64 + " fsd fs0, 0x3c(a0)\n" + " fsd fs1, 0x44(a0)\n" + " fsd fs2, 0x4c(a0)\n" + " fsd fs3, 0x54(a0)\n" + " fsd fs4, 0x5c(a0)\n" + " fsd fs5, 0x64(a0)\n" + " fsd fs6, 0x6c(a0)\n" + " fsd fs7, 0x74(a0)\n" + " fsd fs8, 0x7c(a0)\n" + " fsd fs9, 0x84(a0)\n" + " fsd fs10, 0x8c(a0)\n" + " fsd fs11, 0x94(a0)\n" + " fld fs0, 0x3c(a1)\n" + " fld fs1, 0x44(a1)\n" + " fld fs2, 0x4c(a1)\n" + " fld fs3, 0x54(a1)\n" + " fld fs4, 0x5c(a1)\n" + " fld fs5, 0x64(a1)\n" + " fld fs6, 0x6c(a1)\n" + " fld fs7, 0x74(a1)\n" + " fld fs8, 0x7c(a1)\n" + " fld fs9, 0x84(a1)\n" + " fld fs10, 0x8c(a1)\n" + " fld fs11, 0x94(a1)\n" + #elif __riscv_flen == 32 + " fsw fs0, 0x3c(a0)\n" + " fsw fs1, 0x40(a0)\n" + " fsw fs2, 0x44(a0)\n" + " fsw fs3, 0x48(a0)\n" + " fsw fs4, 0x4c(a0)\n" + " fsw fs5, 0x50(a0)\n" + " fsw fs6, 0x54(a0)\n" + " fsw fs7, 0x58(a0)\n" + " fsw fs8, 0x5c(a0)\n" + " fsw fs9, 0x60(a0)\n" + " fsw fs10, 0x64(a0)\n" + " fsw fs11, 0x68(a0)\n" + " flw fs0, 0x3c(a1)\n" + " flw fs1, 0x40(a1)\n" + " flw fs2, 0x44(a1)\n" + " flw fs3, 0x48(a1)\n" + " flw fs4, 0x4c(a1)\n" + " flw fs5, 0x50(a1)\n" + " flw fs6, 0x54(a1)\n" + " flw fs7, 0x58(a1)\n" + " flw fs8, 0x5c(a1)\n" + " flw fs9, 0x60(a1)\n" + " flw fs10, 0x64(a1)\n" + " flw fs11, 0x68(a1)\n" + #else + #error "Unsupported RISC-V FLEN" + #endif + #endif /* __riscv_flen */ + " lw s0, 0x00(a1)\n" + " lw s1, 0x04(a1)\n" + " lw s2, 0x08(a1)\n" + " lw s3, 0x0c(a1)\n" + " lw s4, 0x10(a1)\n" + " lw s5, 0x14(a1)\n" + " lw s6, 0x18(a1)\n" + " lw s7, 0x1c(a1)\n" + " lw s8, 0x20(a1)\n" + " lw s9, 0x24(a1)\n" + " lw s10, 0x28(a1)\n" + " lw s11, 0x2c(a1)\n" + " lw ra, 0x30(a1)\n" + " lw a2, 0x34(a1)\n" /* pc */ + " lw sp, 0x38(a1)\n" + " jr a2\n" + #else + #error "Unsupported RISC-V XLEN" + #endif /* __riscv_xlen */ + ".size _mco_switch, .-_mco_switch\n" +); + +static mco_result _mco_makectx(mco_coro* co, _mco_ctxbuf* ctx, void* stack_base, size_t stack_size) { + ctx->s[0] = (void*)(co); + ctx->s[1] = (void*)(_mco_main); + ctx->pc = (void*)(_mco_wrap_main); +#if __riscv_xlen == 64 + ctx->ra = (void*)(0xdeaddeaddeaddead); +#elif __riscv_xlen == 32 + ctx->ra = (void*)(0xdeaddead); +#endif + ctx->sp = (void*)((size_t)stack_base + stack_size); + return MCO_SUCCESS; +} + +#elif defined(__i386) || defined(__i386__) + +typedef struct _mco_ctxbuf { + void *eip, *esp, *ebp, *ebx, *esi, *edi; +} _mco_ctxbuf; + +void _mco_switch(_mco_ctxbuf* from, _mco_ctxbuf* to); + +__asm__( + ".text\n" + ".globl _mco_switch\n" + ".type _mco_switch @function\n" + ".hidden _mco_switch\n" + "_mco_switch:\n" + " call 1f\n" + " 1:\n" + " popl %ecx\n" + " addl $(2f-1b), %ecx\n" + " movl 4(%esp), %eax\n" + " movl 8(%esp), %edx\n" + " movl %ecx, (%eax)\n" + " movl %esp, 4(%eax)\n" + " movl %ebp, 8(%eax)\n" + " movl %ebx, 12(%eax)\n" + " movl %esi, 16(%eax)\n" + " movl %edi, 20(%eax)\n" + " movl 20(%edx), %edi\n" + " movl 16(%edx), %esi\n" + " movl 12(%edx), %ebx\n" + " movl 8(%edx), %ebp\n" + " movl 4(%edx), %esp\n" + " jmp *(%edx)\n" + " 2:\n" + " ret\n" + ".size _mco_switch, .-_mco_switch\n" +); + +static mco_result _mco_makectx(mco_coro* co, _mco_ctxbuf* ctx, void* stack_base, size_t stack_size) { + void** stack_high_ptr = (void**)((size_t)stack_base + stack_size - 16 - 1*sizeof(size_t)); + stack_high_ptr[0] = (void*)(0xdeaddead); /* Dummy return address. */ + stack_high_ptr[1] = (void*)(co); + ctx->eip = (void*)(_mco_main); + ctx->esp = (void*)(stack_high_ptr); + return MCO_SUCCESS; +} + +#elif defined(__ARM_EABI__) + +typedef struct _mco_ctxbuf { +#ifndef __SOFTFP__ + void* f[16]; +#endif + void *d[4]; /* d8-d15 */ + void *r[4]; /* r4-r11 */ + void *lr; + void *sp; +} _mco_ctxbuf; + +void _mco_wrap_main(void); +int _mco_switch(_mco_ctxbuf* from, _mco_ctxbuf* to); + +__asm__( + ".text\n" + ".globl _mco_switch\n" + ".type _mco_switch #function\n" + ".hidden _mco_switch\n" + "_mco_switch:\n" +#ifndef __SOFTFP__ + " vstmia r0!, {d8-d15}\n" +#endif + " stmia r0, {r4-r11, lr}\n" + " str sp, [r0, #9*4]\n" +#ifndef __SOFTFP__ + " vldmia r1!, {d8-d15}\n" +#endif + " ldr sp, [r1, #9*4]\n" + " ldmia r1, {r4-r11, pc}\n" + ".size _mco_switch, .-_mco_switch\n" +); + +__asm__( + ".text\n" + ".globl _mco_wrap_main\n" + ".type _mco_wrap_main #function\n" + ".hidden _mco_wrap_main\n" + "_mco_wrap_main:\n" + " mov r0, r4\n" + " mov ip, r5\n" + " mov lr, r6\n" + " bx ip\n" + ".size _mco_wrap_main, .-_mco_wrap_main\n" +); + +static mco_result _mco_makectx(mco_coro* co, _mco_ctxbuf* ctx, void* stack_base, size_t stack_size) { + ctx->d[0] = (void*)(co); + ctx->d[1] = (void*)(_mco_main); + ctx->d[2] = (void*)(0xdeaddead); /* Dummy return address. */ + ctx->lr = (void*)(_mco_wrap_main); + ctx->sp = (void*)((size_t)stack_base + stack_size); + return MCO_SUCCESS; +} + +#elif defined(__aarch64__) + +typedef struct _mco_ctxbuf { + void *x[12]; /* x19-x30 */ + void *sp; + void *lr; + void *d[8]; /* d8-d15 */ +} _mco_ctxbuf; + +void _mco_wrap_main(void); +int _mco_switch(_mco_ctxbuf* from, _mco_ctxbuf* to); + +__asm__( + ".text\n" + ".globl _mco_switch\n" + ".type _mco_switch #function\n" + ".hidden _mco_switch\n" + "_mco_switch:\n" + " mov x10, sp\n" + " mov x11, x30\n" + " stp x19, x20, [x0, #(0*16)]\n" + " stp x21, x22, [x0, #(1*16)]\n" + " stp d8, d9, [x0, #(7*16)]\n" + " stp x23, x24, [x0, #(2*16)]\n" + " stp d10, d11, [x0, #(8*16)]\n" + " stp x25, x26, [x0, #(3*16)]\n" + " stp d12, d13, [x0, #(9*16)]\n" + " stp x27, x28, [x0, #(4*16)]\n" + " stp d14, d15, [x0, #(10*16)]\n" + " stp x29, x30, [x0, #(5*16)]\n" + " stp x10, x11, [x0, #(6*16)]\n" + " ldp x19, x20, [x1, #(0*16)]\n" + " ldp x21, x22, [x1, #(1*16)]\n" + " ldp d8, d9, [x1, #(7*16)]\n" + " ldp x23, x24, [x1, #(2*16)]\n" + " ldp d10, d11, [x1, #(8*16)]\n" + " ldp x25, x26, [x1, #(3*16)]\n" + " ldp d12, d13, [x1, #(9*16)]\n" + " ldp x27, x28, [x1, #(4*16)]\n" + " ldp d14, d15, [x1, #(10*16)]\n" + " ldp x29, x30, [x1, #(5*16)]\n" + " ldp x10, x11, [x1, #(6*16)]\n" + " mov sp, x10\n" + " br x11\n" + ".size _mco_switch, .-_mco_switch\n" +); + +__asm__( + ".text\n" + ".globl _mco_wrap_main\n" + ".type _mco_wrap_main #function\n" + ".hidden _mco_wrap_main\n" + "_mco_wrap_main:\n" + " mov x0, x19\n" + " mov x30, x21\n" + " br x20\n" + ".size _mco_wrap_main, .-_mco_wrap_main\n" +); + +static mco_result _mco_makectx(mco_coro* co, _mco_ctxbuf* ctx, void* stack_base, size_t stack_size) { + ctx->x[0] = (void*)(co); + ctx->x[1] = (void*)(_mco_main); + ctx->x[2] = (void*)(0xdeaddeaddeaddead); /* Dummy return address. */ + ctx->sp = (void*)((size_t)stack_base + stack_size); + ctx->lr = (void*)(_mco_wrap_main); + return MCO_SUCCESS; +} + +#else + +#error "Unsupported architecture for assembly method." + +#endif /* ARCH */ + +#elif defined(MCO_USE_UCONTEXT) + +#include + +typedef ucontext_t _mco_ctxbuf; + +#if defined(_LP64) || defined(__LP64__) +static void _mco_wrap_main(unsigned int lo, unsigned int hi) { + mco_coro* co = (mco_coro*)(((size_t)lo) | (((size_t)hi) << 32)); /* Extract coroutine pointer. */ + _mco_main(co); +} +#else +static void _mco_wrap_main(unsigned int lo) { + mco_coro* co = (mco_coro*)((size_t)lo); /* Extract coroutine pointer. */ + _mco_main(co); +} +#endif + +static MCO_FORCE_INLINE void _mco_switch(_mco_ctxbuf* from, _mco_ctxbuf* to) { + int res = swapcontext(from, to); + _MCO_UNUSED(res); + MCO_ASSERT(res == 0); +} + +static mco_result _mco_makectx(mco_coro* co, _mco_ctxbuf* ctx, void* stack_base, size_t stack_size) { + /* Initialize ucontext. */ + if(getcontext(ctx) != 0) { + MCO_LOG("failed to get ucontext"); + return MCO_MAKE_CONTEXT_ERROR; + } + ctx->uc_link = NULL; /* We never exit from _mco_wrap_main. */ + ctx->uc_stack.ss_sp = stack_base; + ctx->uc_stack.ss_size = stack_size; + unsigned int lo = (unsigned int)((size_t)co); +#if defined(_LP64) || defined(__LP64__) + unsigned int hi = (unsigned int)(((size_t)co)>>32); + makecontext(ctx, (void (*)(void))_mco_wrap_main, 2, lo, hi); +#else + makecontext(ctx, (void (*)(void))_mco_wrap_main, 1, lo); +#endif + return MCO_SUCCESS; +} + +#endif /* defined(MCO_USE_UCONTEXT) */ + +#ifdef MCO_USE_VALGRIND +#include +#endif + +typedef struct _mco_context { +#ifdef MCO_USE_VALGRIND + unsigned int valgrind_stack_id; +#endif + _mco_ctxbuf ctx; + _mco_ctxbuf back_ctx; +} _mco_context; + +static void _mco_jumpin(mco_coro* co) { + _mco_context* context = (_mco_context*)co->context; + _mco_prepare_jumpin(co); + _mco_switch(&context->back_ctx, &context->ctx); /* Do the context switch. */ +} + +static void _mco_jumpout(mco_coro* co) { + _mco_context* context = (_mco_context*)co->context; + _mco_prepare_jumpout(co); + _mco_switch(&context->ctx, &context->back_ctx); /* Do the context switch. */ +} + +static mco_result _mco_create_context(mco_coro* co, mco_desc* desc) { + /* Determine the context and stack address. */ + size_t co_addr = (size_t)co; + size_t context_addr = _mco_align_forward(co_addr + sizeof(mco_coro), 16); + size_t storage_addr = _mco_align_forward(context_addr + sizeof(_mco_context), 16); + size_t stack_addr = _mco_align_forward(storage_addr + desc->storage_size, 16); + /* Initialize context. */ + _mco_context* context = (_mco_context*)context_addr; + memset(context, 0, sizeof(_mco_context)); + /* Initialize storage. */ + unsigned char* storage = (unsigned char*)storage_addr; + memset(storage, 0, desc->storage_size); + /* Initialize stack. */ + void *stack_base = (void*)stack_addr; + size_t stack_size = desc->stack_size; +#ifdef MCO_ZERO_MEMORY + memset(stack_base, 0, stack_size); +#endif + /* Make the context. */ + mco_result res = _mco_makectx(co, &context->ctx, stack_base, stack_size); + if(res != MCO_SUCCESS) { + return res; + } +#ifdef MCO_USE_VALGRIND + context->valgrind_stack_id = VALGRIND_STACK_REGISTER(stack_addr, stack_addr + stack_size); +#endif + co->context = context; + co->stack_base = stack_base; + co->stack_size = stack_size; + co->storage = storage; + co->storage_size = desc->storage_size; + return MCO_SUCCESS; +} + +static void _mco_destroy_context(mco_coro* co) { +#ifdef MCO_USE_VALGRIND + _mco_context* context = (_mco_context*)co->context; + if(context && context->valgrind_stack_id != 0) { + VALGRIND_STACK_DEREGISTER(context->valgrind_stack_id); + context->valgrind_stack_id = 0; + } +#else + _MCO_UNUSED(co); +#endif +} + +static MCO_FORCE_INLINE void _mco_init_desc_sizes(mco_desc* desc, size_t stack_size) { + desc->coro_size = _mco_align_forward(sizeof(mco_coro), 16) + + _mco_align_forward(sizeof(_mco_context), 16) + + _mco_align_forward(desc->storage_size, 16) + + stack_size + 16; + desc->stack_size = stack_size; /* This is just a hint, it won't be the real one. */ +} + +#endif /* defined(MCO_USE_UCONTEXT) || defined(MCO_USE_ASM) */ + +/* ---------------------------------------------------------------------------------------------- */ + +#ifdef MCO_USE_FIBERS + +#ifdef _WIN32 + +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0400 +#endif +#include + +typedef struct _mco_context { + void* fib; + void* back_fib; +} _mco_context; + +static void _mco_jumpin(mco_coro* co) { + void *cur_fib = GetCurrentFiber(); + if(!cur_fib || cur_fib == (void*)0x1e00) { /* See http://blogs.msdn.com/oldnewthing/archive/2004/12/31/344799.aspx */ + cur_fib = ConvertThreadToFiber(NULL); + } + MCO_ASSERT(cur_fib != NULL); + _mco_context* context = (_mco_context*)co->context; + context->back_fib = cur_fib; + _mco_prepare_jumpin(co); + SwitchToFiber(context->fib); +} + +static void CALLBACK _mco_wrap_main(void* co) { + _mco_main((mco_coro*)co); +} + +static void _mco_jumpout(mco_coro* co) { + _mco_context* context = (_mco_context*)co->context; + void* back_fib = context->back_fib; + MCO_ASSERT(back_fib != NULL); + context->back_fib = NULL; + _mco_prepare_jumpout(co); + SwitchToFiber(back_fib); +} + +/* Reverse engineered Fiber struct, used to get stack base. */ +typedef struct _mco_fiber { + LPVOID param; /* fiber param */ + void* except; /* saved exception handlers list */ + void* stack_base; /* top of fiber stack */ + void* stack_limit; /* fiber stack low-water mark */ + void* stack_allocation; /* base of the fiber stack allocation */ + CONTEXT context; /* fiber context */ + DWORD flags; /* fiber flags */ + LPFIBER_START_ROUTINE start; /* start routine */ + void **fls_slots; /* fiber storage slots */ +} _mco_fiber; + +static mco_result _mco_create_context(mco_coro* co, mco_desc* desc) { + /* Determine the context address. */ + size_t co_addr = (size_t)co; + size_t context_addr = _mco_align_forward(co_addr + sizeof(mco_coro), 16); + size_t storage_addr = _mco_align_forward(context_addr + sizeof(_mco_context), 16); + /* Initialize context. */ + _mco_context* context = (_mco_context*)context_addr; + memset(context, 0, sizeof(_mco_context)); + /* Initialize storage. */ + unsigned char* storage = (unsigned char*)storage_addr; + memset(storage, 0, desc->storage_size); + /* Create the fiber. */ + _mco_fiber* fib = (_mco_fiber*)CreateFiberEx(desc->stack_size, desc->stack_size, FIBER_FLAG_FLOAT_SWITCH, _mco_wrap_main, co); + if(!fib) { + MCO_LOG("failed to create fiber"); + return MCO_MAKE_CONTEXT_ERROR; + } + context->fib = fib; + co->context = context; + co->stack_base = fib->stack_base; + co->stack_size = desc->stack_size; + co->storage = storage; + co->storage_size = desc->storage_size; + return MCO_SUCCESS; +} + +static void _mco_destroy_context(mco_coro* co) { + _mco_context* context = (_mco_context*)co->context; + if(context && context->fib) { + DeleteFiber(context->fib); + context->fib = NULL; + } +} + +static MCO_FORCE_INLINE void _mco_init_desc_sizes(mco_desc* desc, size_t stack_size) { + desc->coro_size = _mco_align_forward(sizeof(mco_coro), 16) + + _mco_align_forward(sizeof(_mco_context), 16) + + _mco_align_forward(desc->storage_size, 16) + + 16; + desc->stack_size = stack_size; +} + +#elif defined(__EMSCRIPTEN__) + +#include + +#ifndef MCO_ASYNCFY_STACK_SIZE +#define MCO_ASYNCFY_STACK_SIZE 16384 +#endif + +typedef struct _mco_context { + emscripten_fiber_t fib; + emscripten_fiber_t* back_fib; +} _mco_context; + +static emscripten_fiber_t* running_fib = NULL; +static unsigned char main_asyncify_stack[MCO_ASYNCFY_STACK_SIZE]; +static emscripten_fiber_t main_fib; + +static void _mco_wrap_main(void* co) { + _mco_main((mco_coro*)co); +} + +static void _mco_jumpin(mco_coro* co) { + _mco_context* context = (_mco_context*)co->context; + emscripten_fiber_t* back_fib = running_fib; + if(!back_fib) { + back_fib = &main_fib; + emscripten_fiber_init_from_current_context(back_fib, main_asyncify_stack, MCO_ASYNCFY_STACK_SIZE); + } + running_fib = &context->fib; + context->back_fib = back_fib; + _mco_prepare_jumpin(co); + emscripten_fiber_swap(back_fib, &context->fib); /* Do the context switch. */ +} + +static void _mco_jumpout(mco_coro* co) { + _mco_context* context = (_mco_context*)co->context; + running_fib = context->back_fib; + _mco_prepare_jumpout(co); + emscripten_fiber_swap(&context->fib, context->back_fib); /* Do the context switch. */ +} + +static mco_result _mco_create_context(mco_coro* co, mco_desc* desc) { + if(emscripten_has_asyncify() != 1) { + MCO_LOG("failed to create fiber because ASYNCIFY is not enabled"); + return MCO_MAKE_CONTEXT_ERROR; + } + /* Determine the context address. */ + size_t co_addr = (size_t)co; + size_t context_addr = _mco_align_forward(co_addr + sizeof(mco_coro), 16); + size_t storage_addr = _mco_align_forward(context_addr + sizeof(_mco_context), 16); + size_t stack_addr = _mco_align_forward(storage_addr + desc->storage_size, 16); + size_t asyncify_stack_addr = _mco_align_forward(stack_addr + desc->stack_size, 16); + /* Initialize context. */ + _mco_context* context = (_mco_context*)context_addr; + memset(context, 0, sizeof(_mco_context)); + /* Initialize storage. */ + unsigned char* storage = (unsigned char*)storage_addr; + memset(storage, 0, desc->storage_size); + /* Initialize stack. */ + void *stack_base = (void*)stack_addr; + size_t stack_size = asyncify_stack_addr - stack_addr; + void *asyncify_stack_base = (void*)asyncify_stack_addr; + size_t asyncify_stack_size = co_addr + desc->coro_size - asyncify_stack_addr; +#ifdef MCO_ZERO_MEMORY + memset(stack_base, 0, stack_size); + memset(asyncify_stack_base, 0, asyncify_stack_size); +#endif + /* Create the fiber. */ + emscripten_fiber_init(&context->fib, _mco_wrap_main, co, stack_base, stack_size, asyncify_stack_base, asyncify_stack_size); + co->context = context; + co->stack_base = stack_base; + co->stack_size = stack_size; + co->storage = storage; + co->storage_size = desc->storage_size; + return MCO_SUCCESS; +} + +static void _mco_destroy_context(mco_coro* co) { + /* Nothing to do. */ + _MCO_UNUSED(co); +} + +static MCO_FORCE_INLINE void _mco_init_desc_sizes(mco_desc* desc, size_t stack_size) { + desc->coro_size = _mco_align_forward(sizeof(mco_coro), 16) + + _mco_align_forward(sizeof(_mco_context), 16) + + _mco_align_forward(desc->storage_size, 16) + + _mco_align_forward(stack_size, 16) + + _mco_align_forward(MCO_ASYNCFY_STACK_SIZE, 16) + + 16; + desc->stack_size = stack_size; /* This is just a hint, it won't be the real one. */ +} + +#else + +#error "Unsupported architecture for fibers method." + +#endif + +#endif /* MCO_USE_FIBERS */ + +/* ---------------------------------------------------------------------------------------------- */ + +mco_desc mco_desc_init(void (*func)(mco_coro* co), size_t stack_size) { + if(stack_size != 0) { + /* Stack size should be at least `MCO_MIN_STACK_SIZE`. */ + if(stack_size < MCO_MIN_STACK_SIZE) { + stack_size = MCO_MIN_STACK_SIZE; + } + } else { + stack_size = MCO_DEFAULT_STACK_SIZE; + } + stack_size = _mco_align_forward(stack_size, 16); /* Stack size should be aligned to 16 bytes. */ + mco_desc desc; + memset(&desc, 0, sizeof(mco_desc)); +#ifndef MCO_NO_DEFAULT_ALLOCATORS + /* Set default allocators. */ + desc.malloc_cb = mco_malloc; + desc.free_cb = mco_free; +#endif + desc.func = func; + desc.storage_size = MCO_DEFAULT_STORAGE_SIZE; + _mco_init_desc_sizes(&desc, stack_size); + return desc; +} + +static mco_result _mco_validate_desc(mco_desc* desc) { + if(!desc) { + MCO_LOG("coroutine description is NULL"); + return MCO_INVALID_ARGUMENTS; + } + if(!desc->func) { + MCO_LOG("coroutine function in invalid"); + return MCO_INVALID_ARGUMENTS; + } + if(desc->stack_size < MCO_MIN_STACK_SIZE) { + MCO_LOG("coroutine stack size is too small"); + return MCO_INVALID_ARGUMENTS; + } + if(desc->coro_size < sizeof(mco_coro)) { + MCO_LOG("coroutine size is invalid"); + return MCO_INVALID_ARGUMENTS; + } + return MCO_SUCCESS; +} + +mco_result mco_init(mco_coro* co, mco_desc* desc) { + if(!co) { + MCO_LOG("attempt to initialize an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } + memset(co, 0, sizeof(mco_coro)); + /* Validate coroutine description. */ + mco_result res = _mco_validate_desc(desc); + if(res != MCO_SUCCESS) + return res; + /* Create the coroutine. */ + res = _mco_create_context(co, desc); + if(res != MCO_SUCCESS) + return res; + co->state = MCO_SUSPENDED; /* We initialize in suspended state. */ + co->free_cb = desc->free_cb; + co->allocator_data = desc->allocator_data; + co->func = desc->func; + co->user_data = desc->user_data; +#ifdef _MCO_USE_TSAN + co->tsan_fiber = __tsan_create_fiber(0); +#endif + return MCO_SUCCESS; +} + +mco_result mco_uninit(mco_coro* co) { + if(!co) { + MCO_LOG("attempt to uninitialize an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } + /* Cannot uninitialize while running. */ + if(!(co->state == MCO_SUSPENDED || co->state == MCO_DEAD)) { + MCO_LOG("attempt to uninitialize a coroutine that is not dead or suspended"); + return MCO_INVALID_OPERATION; + } + /* The coroutine is now dead and cannot be used anymore. */ + co->state = MCO_DEAD; +#ifdef _MCO_USE_TSAN + if(co->tsan_fiber != NULL) { + __tsan_destroy_fiber(co->tsan_fiber); + co->tsan_fiber = NULL; + } +#endif + _mco_destroy_context(co); + return MCO_SUCCESS; +} + +mco_result mco_create(mco_coro** out_co, mco_desc* desc) { + /* Validate input. */ + if(!out_co) { + MCO_LOG("coroutine output pointer is NULL"); + return MCO_INVALID_POINTER; + } + if(!desc || !desc->malloc_cb || !desc->free_cb) { + *out_co = NULL; + MCO_LOG("coroutine allocator description is not set"); + return MCO_INVALID_ARGUMENTS; + } + /* Allocate the coroutine. */ + mco_coro* co = (mco_coro*)desc->malloc_cb(desc->coro_size, desc->allocator_data); + if(!co) { + MCO_LOG("coroutine allocation failed"); + *out_co = NULL; + return MCO_OUT_OF_MEMORY; + } + /* Initialize the coroutine. */ + mco_result res = mco_init(co, desc); + if(res != MCO_SUCCESS) { + desc->free_cb(co, desc->allocator_data); + *out_co = NULL; + return res; + } + *out_co = co; + return MCO_SUCCESS; +} + +mco_result mco_destroy(mco_coro* co) { + if(!co) { + MCO_LOG("attempt to destroy an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } + /* Uninitialize the coroutine first. */ + mco_result res = mco_uninit(co); + if(res != MCO_SUCCESS) + return res; + /* Free the coroutine. */ + if(!co->free_cb) { + MCO_LOG("attempt destroy a coroutine that has no free callback"); + return MCO_INVALID_POINTER; + } + co->free_cb(co, co->allocator_data); + return MCO_SUCCESS; +} + +mco_result mco_resume(mco_coro* co) { + if(!co) { + MCO_LOG("attempt to resume an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } + if(co->state != MCO_SUSPENDED) { /* Can only resume coroutines that are suspended. */ + MCO_LOG("attempt to resume a coroutine that is not suspended"); + return MCO_NOT_SUSPENDED; + } + co->state = MCO_RUNNING; /* The coroutine is now running. */ + _mco_jumpin(co); + return MCO_SUCCESS; +} + +mco_result mco_yield(mco_coro* co) { + if(!co) { + MCO_LOG("attempt to yield an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } + if(co->state != MCO_RUNNING) { /* Can only yield coroutines that are running. */ + MCO_LOG("attempt to yield a coroutine that is not running"); + return MCO_NOT_RUNNING; + } + co->state = MCO_SUSPENDED; /* The coroutine is now suspended. */ + _mco_jumpout(co); + return MCO_SUCCESS; +} + +mco_state mco_status(mco_coro* co) { + if(co != NULL) { + return co->state; + } + return MCO_DEAD; +} + +void* mco_get_user_data(mco_coro* co) { + if(co != NULL) { + return co->user_data; + } + return NULL; +} + +mco_result mco_push(mco_coro* co, const void* src, size_t len) { + if(!co) { + MCO_LOG("attempt to use an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } else if(len > 0) { + size_t bytes_stored = co->bytes_stored + len; + if(bytes_stored > co->storage_size) { + MCO_LOG("attempt to push bytes too many bytes into coroutine storage"); + return MCO_NOT_ENOUGH_SPACE; + } + if(!src) { + MCO_LOG("attempt push a null pointer into coroutine storage"); + return MCO_INVALID_POINTER; + } + memcpy(&co->storage[co->bytes_stored], src, len); + co->bytes_stored = bytes_stored; + } + return MCO_SUCCESS; +} + +mco_result mco_pop(mco_coro* co, void* dest, size_t len) { + if(!co) { + MCO_LOG("attempt to use an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } else if(len > 0) { + if(len > co->bytes_stored) { + MCO_LOG("attempt to pop too many bytes from coroutine storage"); + return MCO_NOT_ENOUGH_SPACE; + } + size_t bytes_stored = co->bytes_stored - len; + if(dest) { + memcpy(dest, &co->storage[bytes_stored], len); + } + co->bytes_stored = bytes_stored; +#ifdef MCO_ZERO_MEMORY + /* Clear garbage in the discarded storage. */ + memset(&co->storage[bytes_stored], 0, len); +#endif + } + return MCO_SUCCESS; +} + +mco_result mco_peek(mco_coro* co, void* dest, size_t len) { + if(!co) { + MCO_LOG("attempt to use an invalid coroutine"); + return MCO_INVALID_COROUTINE; + } else if(len > 0) { + if(len > co->bytes_stored) { + MCO_LOG("attempt to peek too many bytes from coroutine storage"); + return MCO_NOT_ENOUGH_SPACE; + } + if(!dest) { + MCO_LOG("attempt peek into a null pointer"); + return MCO_INVALID_POINTER; + } + memcpy(dest, &co->storage[co->bytes_stored - len], len); + } + return MCO_SUCCESS; +} + +size_t mco_get_bytes_stored(mco_coro* co) { + if(co == NULL) { + return 0; + } + return co->bytes_stored; +} + +size_t mco_get_storage_size(mco_coro* co) { + if(co == NULL) { + return 0; + } + return co->storage_size; +} + +#ifdef MCO_NO_MULTITHREAD +mco_coro* mco_running(void) { + return mco_current_co; +} +#else +static mco_coro* _mco_running(void) { + return mco_current_co; +} +mco_coro* mco_running(void) { + /* + Compilers aggressively optimize the use of TLS by caching loads. + Since fiber code can migrate between threads it’s possible for the load to be stale. + To prevent this from happening we avoid inline functions. + */ + mco_coro* (*volatile func)(void) = _mco_running; + return func(); +} +#endif + +const char* mco_result_description(mco_result res) { + switch(res) { + case MCO_SUCCESS: + return "No error"; + case MCO_GENERIC_ERROR: + return "Generic error"; + case MCO_INVALID_POINTER: + return "Invalid pointer"; + case MCO_INVALID_COROUTINE: + return "Invalid coroutine"; + case MCO_NOT_SUSPENDED: + return "Coroutine not suspended"; + case MCO_NOT_RUNNING: + return "Coroutine not running"; + case MCO_MAKE_CONTEXT_ERROR: + return "Make context error"; + case MCO_SWITCH_CONTEXT_ERROR: + return "Switch context error"; + case MCO_NOT_ENOUGH_SPACE: + return "Not enough space"; + case MCO_OUT_OF_MEMORY: + return "Out of memory"; + case MCO_INVALID_ARGUMENTS: + return "Invalid arguments"; + case MCO_INVALID_OPERATION: + return "Invalid operation"; + default: + return "Unknown error"; + } +} + +#ifdef __cplusplus +} +#endif + +#endif /* MINICORO_IMPL */ + +/* +This software is available as a choice of the following licenses. Choose +whichever you prefer. + +=============================================================================== +ALTERNATIVE 1 - Public Domain (www.unlicense.org) +=============================================================================== +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. + +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to + +=============================================================================== +ALTERNATIVE 2 - MIT No Attribution +=============================================================================== +Copyright (c) 2021 Eduardo Bart (https://github.com/edubart/minicoro) + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ \ No newline at end of file diff --git a/framework/3rd/readme.txt b/framework/3rd/readme.txt new file mode 100644 index 0000000..e69de29 diff --git a/framework/lualib/thirdparty/bint/bint.lua b/framework/lualib/thirdparty/bint/bint.lua new file mode 100644 index 0000000..a9ca81f --- /dev/null +++ b/framework/lualib/thirdparty/bint/bint.lua @@ -0,0 +1,1632 @@ +--[[-- +lua-bint - v0.4.1 - 28/Jan/2022 +Eduardo Bart - edub4rt@gmail.com +https://github.com/edubart/lua-bint + +Small portable arbitrary-precision integer arithmetic library in pure Lua for +computing with large integers. + +Different from most arbitrary-precision integer libraries in pure Lua out there this one +uses an array of lua integers as underlying data-type in its implementation instead of +using strings or large tables, this make it efficient for working with fixed width integers +and to make bitwise operations. + +## Design goals + +The main design goal of this library is to be small, correct, self contained and use few +resources while retaining acceptable performance and feature completeness. + +The library is designed to follow recent Lua integer semantics, this means that +integer overflow warps around, +signed integers are implemented using two-complement arithmetic rules, +integer division operations rounds towards minus infinity, +any mixed operations with float numbers promotes the value to a float, +and the usual division/power operation always promotes to floats. + +The library is designed to be possible to work with only unsigned integer arithmetic +when using the proper methods. + +All the lua arithmetic operators (+, -, *, //, /, %) and bitwise operators (&, |, ~, <<, >>) +are implemented as metamethods. + +The integer size must be fixed in advance and the library is designed to be more efficient when +working with integers of sizes between 64-4096 bits. If you need to work with really huge numbers +without size restrictions then use another library. This choice has been made to have more efficiency +in that specific size range. + +## Usage + +First on you should require the bint file including how many bits the bint module will work with, +by calling the returned function from the require, for example: + +```lua +local bint = require 'bint'(1024) +``` + +For more information about its arguments see @{newmodule}. +Then when you need create a bint, you can use one of the following functions: + +* @{bint.fromuinteger} (convert from lua integers, but read as unsigned integer) +* @{bint.frominteger} (convert from lua integers, preserving the sign) +* @{bint.frombase} (convert from arbitrary bases, like hexadecimal) +* @{bint.trunc} (convert from lua numbers, truncating the fractional part) +* @{bint.new} (convert from anything, asserts on invalid integers) +* @{bint.tobint} (convert from anything, returns nil on invalid integers) +* @{bint.parse} (convert from anything, returns a lua number as fallback) +* @{bint.zero} +* @{bint.one} +* `bint` + +You can also call `bint` as it is an alias to `bint.new`. +In doubt use @{bint.new} to create a new bint. + +Then you can use all the usual lua numeric operations on it, +all the arithmetic metamethods are implemented. +When you are done computing and need to get the result, +get the output from one of the following functions: + +* @{bint.touinteger} (convert to a lua integer, wraps around as an unsigned integer) +* @{bint.tointeger} (convert to a lua integer, wraps around, preserves the sign) +* @{bint.tonumber} (convert to lua float, losing precision) +* @{bint.tobase} (convert to a string in any base) +* @{bint.__tostring} (convert to a string in base 10) + +To output a very large integer with no loss you probably want to use @{bint.tobase} +or call `tostring` to get a string representation. + +## Precautions + +All library functions can be mixed with lua numbers, +this makes easy to mix operations between bints and lua numbers, +however the user should take care in some situations: + +* Don't mix integers and float operations if you want to work with integers only. +* Don't use the regular equal operator ('==') to compare values from this library, +unless you know in advance that both values are of the same primitive type, +otherwise it will always return false, use @{bint.eq} to be safe. +* Don't pass fractional numbers to functions that an integer is expected +* Don't mix operations between bint classes with different sizes as this is not supported, this +will throw assertions. +* Remember that casting back to lua integers or numbers precision can be lost. +* For dividing while preserving integers use the @{bint.__idiv} (the '//' operator). +* For doing power operation preserving integers use the @{bint.ipow} function. +* Configure the proper integer size you intend to work with, otherwise large integers may wrap around. + +]] -- Returns number of bits of the internal lua integer type. +local function luainteger_bitsize() + local n, i = -1, 0 + repeat + n, i = n >> 16, i + 16 + until n == 0 + return i +end + +local math_type = math.type +local math_floor = math.floor +local math_abs = math.abs +local math_ceil = math.ceil +local math_modf = math.modf +local math_maxinteger = math.maxinteger +local math_max = math.max +local math_min = math.min +local string_format = string.format +local table_insert = table.insert +local table_concat = table.concat + +local memo = {} + +--- Create a new bint module representing integers of the desired bit size. +-- This is the returned function when `require 'bint'` is called. +-- @function newmodule +-- @param bits Number of bits for the integer representation, must be multiple of wordbits and +-- at least 64. +-- @param[opt] wordbits Number of the bits for the internal word, +-- defaults to half of Lua's integer size. +local function newmodule(bits, wordbits) + + local intbits = luainteger_bitsize() + bits = bits or 256 + wordbits = wordbits or (intbits // 2) + + -- Memoize bint modules + local memoindex = bits * 64 + wordbits + if memo[memoindex] then + return memo[memoindex] + end + + -- Validate + assert(bits % wordbits == 0, 'bitsize is not multiple of word bitsize') + assert(2 * wordbits <= intbits, 'word bitsize must be half of the lua integer bitsize') + assert(bits >= 64, 'bitsize must be >= 64') + + -- Create bint module + local bint = {} + bint.__index = bint + + --- Number of bits representing a bint instance. + bint.bits = bits + + -- Constants used internally + local BINT_BITS = bits + local BINT_WORDBITS = wordbits + local BINT_SIZE = BINT_BITS // BINT_WORDBITS + local BINT_WORDMAX = (1 << BINT_WORDBITS) - 1 + local BINT_WORDMSB = (1 << (BINT_WORDBITS - 1)) + local BINT_MATHMININTEGER, BINT_MATHMAXINTEGER + local BINT_MININTEGER + + --- Create a new bint with 0 value. + function bint.zero() + local x = setmetatable({}, bint) + for i = 1, BINT_SIZE do + x[i] = 0 + end + return x + end + local bint_zero = bint.zero + + --- Create a new bint with 1 value. + function bint.one() + local x = setmetatable({}, bint) + x[1] = 1 + for i = 2, BINT_SIZE do + x[i] = 0 + end + return x + end + local bint_one = bint.one + + -- Convert a value to a lua integer without losing precision. + local function tointeger(x) + x = tonumber(x) + local ty = math_type(x) + if ty == 'float' then + local floorx = math_floor(x) + if floorx == x then + x = floorx + ty = math_type(x) + end + end + if ty == 'integer' then + return x + end + end + + --- Create a bint from an unsigned integer. + -- Treats signed integers as an unsigned integer. + -- @param x A value to initialize from convertible to a lua integer. + -- @return A new bint or nil in case the input cannot be represented by an integer. + -- @see bint.frominteger + function bint.fromuinteger(x) + x = tointeger(x) + if x then + if x == 1 then + return bint_one() + elseif x == 0 then + return bint_zero() + end + local n = setmetatable({}, bint) + for i = 1, BINT_SIZE do + n[i] = x & BINT_WORDMAX + x = x >> BINT_WORDBITS + end + return n + end + end + local bint_fromuinteger = bint.fromuinteger + + --- Create a bint from a signed integer. + -- @param x A value to initialize from convertible to a lua integer. + -- @return A new bint or nil in case the input cannot be represented by an integer. + -- @see bint.fromuinteger + function bint.frominteger(x) + x = tointeger(x) + if x then + if x == 1 then + return bint_one() + elseif x == 0 then + return bint_zero() + end + local neg = false + if x < 0 then + x = math_abs(x) + neg = true + end + local n = setmetatable({}, bint) + for i = 1, BINT_SIZE do + n[i] = x & BINT_WORDMAX + x = x >> BINT_WORDBITS + end + if neg then + n:_unm() + end + return n + end + end + local bint_frominteger = bint.frominteger + + local basesteps = {} + + -- Compute the read step for frombase function + local function getbasestep(base) + local step = basesteps[base] + if step then + return step + end + step = 0 + local dmax = 1 + local limit = math_maxinteger // base + repeat + step = step + 1 + dmax = dmax * base + until dmax >= limit + basesteps[base] = step + return step + end + + -- Compute power with lua integers. + local function ipow(y, x, n) + if n == 1 then + return y * x + elseif n & 1 == 0 then -- even + return ipow(y, x * x, n // 2) + end + return ipow(x * y, x * x, (n - 1) // 2) + end + + --- Create a bint from a string of the desired base. + -- @param s The string to be converted from, + -- must have only alphanumeric and '+-' characters. + -- @param[opt] base Base that the number is represented, defaults to 10. + -- Must be at least 2 and at most 36. + -- @return A new bint or nil in case the conversion failed. + function bint.frombase(s, base) + if type(s) ~= 'string' then + return + end + base = base or 10 + if not (base >= 2 and base <= 36) then + -- number base is too large + return + end + local step = getbasestep(base) + if #s < step then + -- string is small, use tonumber (faster) + return bint_frominteger(tonumber(s, base)) + end + local sign, int = s:lower():match('^([+-]?)(%w+)$') + if not (sign and int) then + -- invalid integer string representation + return + end + local n = bint_zero() + for i = 1, #int, step do + local part = int:sub(i, i + step - 1) + local d = tonumber(part, base) + if not d then + -- invalid integer string representation + return + end + if i > 1 then + n = n * ipow(1, base, #part) + end + if d ~= 0 then + n:_add(d) + end + end + if sign == '-' then + n:_unm() + end + return n + end + local bint_frombase = bint.frombase + + --- Create a new bint from a value. + -- @param x A value convertible to a bint (string, number or another bint). + -- @return A new bint, guaranteed to be a new reference in case needed. + -- @raise An assert is thrown in case x is not convertible to a bint. + -- @see bint.tobint + -- @see bint.parse + function bint.new(x) + if getmetatable(x) ~= bint then + local ty = type(x) + if ty == 'number' then + return bint_frominteger(x) + elseif ty == 'string' then + return bint_frombase(x, 10) + end + error('value cannot be represented by a bint') + end + -- return a clone + local n = setmetatable({}, bint) + for i = 1, BINT_SIZE do + n[i] = x[i] + end + return n + end + local bint_new = bint.new + + --- Convert a value to a bint if possible. + -- @param x A value to be converted (string, number or another bint). + -- @param[opt] clone A boolean that tells if a new bint reference should be returned. + -- Defaults to false. + -- @return A bint or nil in case the conversion failed. + -- @see bint.new + -- @see bint.parse + function bint.tobint(x, clone) + if getmetatable(x) == bint then + if not clone then + return x + end + -- return a clone + local n = setmetatable({}, bint) + for i = 1, BINT_SIZE do + n[i] = x[i] + end + return n + end + local ty = type(x) + if ty == 'number' then + return bint_frominteger(x) + elseif ty == 'string' then + return bint_frombase(x, 10) + end + end + local tobint = bint.tobint + + --- Convert a value to a bint if possible otherwise to a lua number. + -- Useful to prepare values that you are unsure if it's going to be an integer or float. + -- @param x A value to be converted (string, number or another bint). + -- @param[opt] clone A boolean that tells if a new bint reference should be returned. + -- Defaults to false. + -- @return A bint or a lua number or nil in case the conversion failed. + -- @see bint.new + -- @see bint.tobint + function bint.parse(x, clone) + local i = tobint(x, clone) + if i then + return i + end + return tonumber(x) + end + local bint_parse = bint.parse + + --- Convert a bint to an unsigned integer. + -- Note that large unsigned integers may be represented as negatives in lua integers. + -- Note that lua cannot represent values larger than 64 bits, + -- in that case integer values wrap around. + -- @param x A bint or a number to be converted into an unsigned integer. + -- @return An integer or nil in case the input cannot be represented by an integer. + -- @see bint.tointeger + function bint.touinteger(x) + if getmetatable(x) == bint then + local n = 0 + for i = 1, BINT_SIZE do + n = n | (x[i] << (BINT_WORDBITS * (i - 1))) + end + return n + end + return tointeger(x) + end + + --- Convert a bint to a signed integer. + -- It works by taking absolute values then applying the sign bit in case needed. + -- Note that lua cannot represent values larger than 64 bits, + -- in that case integer values wrap around. + -- @param x A bint or value to be converted into an unsigned integer. + -- @return An integer or nil in case the input cannot be represented by an integer. + -- @see bint.touinteger + function bint.tointeger(x) + if getmetatable(x) == bint then + local n = 0 + local neg = x:isneg() + if neg then + x = -x + end + for i = 1, BINT_SIZE do + n = n | (x[i] << (BINT_WORDBITS * (i - 1))) + end + if neg then + n = -n + end + return n + end + return tointeger(x) + end + local bint_tointeger = bint.tointeger + + local function bint_assert_tointeger(x) + x = bint_tointeger(x) + if not x then + error('value has no integer representation') + end + return x + end + + --- Convert a bint to a lua float in case integer would wrap around or lua integer otherwise. + -- Different from @{bint.tointeger} the operation does not wrap around integers, + -- but digits precision are lost in the process of converting to a float. + -- @param x A bint or value to be converted into a lua number. + -- @return A lua number or nil in case the input cannot be represented by a number. + -- @see bint.tointeger + function bint.tonumber(x) + if getmetatable(x) == bint then + if x <= BINT_MATHMAXINTEGER and x >= BINT_MATHMININTEGER then + return x:tointeger() + end + return tonumber(tostring(x)) + end + return tonumber(x) + end + local bint_tonumber = bint.tonumber + + -- Compute base letters to use in bint.tobase + local BASE_LETTERS = {} + do + for i = 1, 36 do + BASE_LETTERS[i - 1] = ('0123456789abcdefghijklmnopqrstuvwxyz'):sub(i, i) + end + end + + --- Convert a bint to a string in the desired base. + -- @param x The bint to be converted from. + -- @param[opt] base Base to be represented, defaults to 10. + -- Must be at least 2 and at most 36. + -- @param[opt] unsigned Whether to output as an unsigned integer. + -- Defaults to false for base 10 and true for others. + -- When unsigned is false the symbol '-' is prepended in negative values. + -- @return A string representing the input. + -- @raise An assert is thrown in case the base is invalid. + function bint.tobase(x, base, unsigned) + x = tobint(x) + if not x then + -- x is a fractional float or something else + return + end + base = base or 10 + if not (base >= 2 and base <= 36) then + -- number base is too large + return + end + if unsigned == nil then + unsigned = base ~= 10 + end + local isxneg = x:isneg() + if (base == 10 and not unsigned) or (base == 16 and unsigned and not isxneg) then + if x <= BINT_MATHMAXINTEGER and x >= BINT_MATHMININTEGER then + -- integer is small, use tostring or string.format (faster) + local n = x:tointeger() + if base == 10 then + return tostring(n) + elseif unsigned then + return string_format('%x', n) + end + end + end + local ss = {} + local neg = not unsigned and isxneg + x = neg and x:abs() or bint_new(x) + local xiszero = x:iszero() + if xiszero then + return '0' + end + -- calculate basepow + local step = 0 + local basepow = 1 + local limit = (BINT_WORDMSB - 1) // base + repeat + step = step + 1 + basepow = basepow * base + until basepow >= limit + -- serialize base digits + local size = BINT_SIZE + local xd, carry, d + repeat + -- single word division + carry = 0 + xiszero = true + for i = size, 1, -1 do + carry = carry | x[i] + d, xd = carry // basepow, carry % basepow + if xiszero and d ~= 0 then + size = i + xiszero = false + end + x[i] = d + carry = xd << BINT_WORDBITS + end + -- digit division + for _ = 1, step do + xd, d = xd // base, xd % base + if xiszero and xd == 0 and d == 0 then + -- stop on leading zeros + break + end + table_insert(ss, 1, BASE_LETTERS[d]) + end + until xiszero + if neg then + table_insert(ss, 1, '-') + end + return table_concat(ss) + end + + local function bint_assert_convert(x) + return assert(tobint(x), 'value has not integer representation') + end + + --- Check if a number is 0 considering bints. + -- @param x A bint or a lua number. + function bint.iszero(x) + if getmetatable(x) == bint then + for i = 1, BINT_SIZE do + if x[i] ~= 0 then + return false + end + end + return true + end + return x == 0 + end + + --- Check if a number is 1 considering bints. + -- @param x A bint or a lua number. + function bint.isone(x) + if getmetatable(x) == bint then + if x[1] ~= 1 then + return false + end + for i = 2, BINT_SIZE do + if x[i] ~= 0 then + return false + end + end + return true + end + return x == 1 + end + + --- Check if a number is -1 considering bints. + -- @param x A bint or a lua number. + function bint.isminusone(x) + if getmetatable(x) == bint then + for i = 1, BINT_SIZE do + if x[i] ~= BINT_WORDMAX then + return false + end + end + return true + end + return x == -1 + end + local bint_isminusone = bint.isminusone + + --- Check if the input is a bint. + -- @param x Any lua value. + function bint.isbint(x) + return getmetatable(x) == bint + end + + --- Check if the input is a lua integer or a bint. + -- @param x Any lua value. + function bint.isintegral(x) + return getmetatable(x) == bint or math_type(x) == 'integer' + end + + --- Check if the input is a bint or a lua number. + -- @param x Any lua value. + function bint.isnumeric(x) + return getmetatable(x) == bint or type(x) == 'number' + end + + --- Get the number type of the input (bint, integer or float). + -- @param x Any lua value. + -- @return Returns "bint" for bints, "integer" for lua integers, + -- "float" from lua floats or nil otherwise. + function bint.type(x) + if getmetatable(x) == bint then + return 'bint' + end + return math_type(x) + end + + --- Check if a number is negative considering bints. + -- Zero is guaranteed to never be negative for bints. + -- @param x A bint or a lua number. + function bint.isneg(x) + if getmetatable(x) == bint then + return x[BINT_SIZE] & BINT_WORDMSB ~= 0 + end + return x < 0 + end + local bint_isneg = bint.isneg + + --- Check if a number is positive considering bints. + -- @param x A bint or a lua number. + function bint.ispos(x) + if getmetatable(x) == bint then + return not x:isneg() and not x:iszero() + end + return x > 0 + end + + --- Check if a number is even considering bints. + -- @param x A bint or a lua number. + function bint.iseven(x) + if getmetatable(x) == bint then + return x[1] & 1 == 0 + end + return math_abs(x) % 2 == 0 + end + + --- Check if a number is odd considering bints. + -- @param x A bint or a lua number. + function bint.isodd(x) + if getmetatable(x) == bint then + return x[1] & 1 == 1 + end + return math_abs(x) % 2 == 1 + end + + --- Create a new bint with the maximum possible integer value. + function bint.maxinteger() + local x = setmetatable({}, bint) + for i = 1, BINT_SIZE - 1 do + x[i] = BINT_WORDMAX + end + x[BINT_SIZE] = BINT_WORDMAX ~ BINT_WORDMSB + return x + end + + --- Create a new bint with the minimum possible integer value. + function bint.mininteger() + local x = setmetatable({}, bint) + for i = 1, BINT_SIZE - 1 do + x[i] = 0 + end + x[BINT_SIZE] = BINT_WORDMSB + return x + end + + --- Bitwise left shift a bint in one bit (in-place). + function bint:_shlone() + local wordbitsm1 = BINT_WORDBITS - 1 + for i = BINT_SIZE, 2, -1 do + self[i] = ((self[i] << 1) | (self[i - 1] >> wordbitsm1)) & BINT_WORDMAX + end + self[1] = (self[1] << 1) & BINT_WORDMAX + return self + end + + --- Bitwise right shift a bint in one bit (in-place). + function bint:_shrone() + local wordbitsm1 = BINT_WORDBITS - 1 + for i = 1, BINT_SIZE - 1 do + self[i] = ((self[i] >> 1) | (self[i + 1] << wordbitsm1)) & BINT_WORDMAX + end + self[BINT_SIZE] = self[BINT_SIZE] >> 1 + return self + end + + -- Bitwise left shift words of a bint (in-place). Used only internally. + function bint:_shlwords(n) + for i = BINT_SIZE, n + 1, -1 do + self[i] = self[i - n] + end + for i = 1, n do + self[i] = 0 + end + return self + end + + -- Bitwise right shift words of a bint (in-place). Used only internally. + function bint:_shrwords(n) + if n < BINT_SIZE then + for i = 1, BINT_SIZE - n do + self[i] = self[i + n] + end + for i = BINT_SIZE - n + 1, BINT_SIZE do + self[i] = 0 + end + else + for i = 1, BINT_SIZE do + self[i] = 0 + end + end + return self + end + + --- Increment a bint by one (in-place). + function bint:_inc() + for i = 1, BINT_SIZE do + local tmp = self[i] + local v = (tmp + 1) & BINT_WORDMAX + self[i] = v + if v > tmp then + break + end + end + return self + end + + --- Increment a number by one considering bints. + -- @param x A bint or a lua number to increment. + function bint.inc(x) + local ix = tobint(x, true) + if ix then + return ix:_inc() + end + return x + 1 + end + + --- Decrement a bint by one (in-place). + function bint:_dec() + for i = 1, BINT_SIZE do + local tmp = self[i] + local v = (tmp - 1) & BINT_WORDMAX + self[i] = v + if not (v > tmp) then + break + end + end + return self + end + + --- Decrement a number by one considering bints. + -- @param x A bint or a lua number to decrement. + function bint.dec(x) + local ix = tobint(x, true) + if ix then + return ix:_dec() + end + return x - 1 + end + + --- Assign a bint to a new value (in-place). + -- @param y A value to be copied from. + -- @raise Asserts in case inputs are not convertible to integers. + function bint:_assign(y) + y = bint_assert_convert(y) + for i = 1, BINT_SIZE do + self[i] = y[i] + end + return self + end + + --- Take absolute of a bint (in-place). + function bint:_abs() + if self:isneg() then + self:_unm() + end + return self + end + + --- Take absolute of a number considering bints. + -- @param x A bint or a lua number to take the absolute. + function bint.abs(x) + local ix = tobint(x, true) + if ix then + return ix:_abs() + end + return math_abs(x) + end + local bint_abs = bint.abs + + --- Take the floor of a number considering bints. + -- @param x A bint or a lua number to perform the floor operation. + function bint.floor(x) + if getmetatable(x) == bint then + return bint_new(x) + end + return bint_new(math_floor(tonumber(x))) + end + + --- Take ceil of a number considering bints. + -- @param x A bint or a lua number to perform the ceil operation. + function bint.ceil(x) + if getmetatable(x) == bint then + return bint_new(x) + end + return bint_new(math_ceil(tonumber(x))) + end + + --- Wrap around bits of an integer (discarding left bits) considering bints. + -- @param x A bint or a lua integer. + -- @param y Number of right bits to preserve. + function bint.bwrap(x, y) + x = bint_assert_convert(x) + if y <= 0 then + return bint_zero() + elseif y < BINT_BITS then + return x & (bint_one() << y):_dec() + end + return bint_new(x) + end + + --- Rotate left integer x by y bits considering bints. + -- @param x A bint or a lua integer. + -- @param y Number of bits to rotate. + function bint.brol(x, y) + x, y = bint_assert_convert(x), bint_assert_tointeger(y) + if y > 0 then + return (x << y) | (x >> (BINT_BITS - y)) + elseif y < 0 then + return x:bror(-y) + end + return x + end + + --- Rotate right integer x by y bits considering bints. + -- @param x A bint or a lua integer. + -- @param y Number of bits to rotate. + function bint.bror(x, y) + x, y = bint_assert_convert(x), bint_assert_tointeger(y) + if y > 0 then + return (x >> y) | (x << (BINT_BITS - y)) + elseif y < 0 then + return x:brol(-y) + end + return x + end + + --- Truncate a number to a bint. + -- Floats numbers are truncated, that is, the fractional port is discarded. + -- @param x A number to truncate. + -- @return A new bint or nil in case the input does not fit in a bint or is not a number. + function bint.trunc(x) + if getmetatable(x) ~= bint then + x = tonumber(x) + if x then + local ty = math_type(x) + if ty == 'float' then + -- truncate to integer + x = math_modf(x) + end + return bint_frominteger(x) + end + return + end + return bint_new(x) + end + + --- Take maximum between two numbers considering bints. + -- @param x A bint or lua number to compare. + -- @param y A bint or lua number to compare. + -- @return A bint or a lua number. Guarantees to return a new bint for integer values. + function bint.max(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + return bint_new(ix > iy and ix or iy) + end + return bint_parse(math_max(x, y)) + end + + --- Take minimum between two numbers considering bints. + -- @param x A bint or lua number to compare. + -- @param y A bint or lua number to compare. + -- @return A bint or a lua number. Guarantees to return a new bint for integer values. + function bint.min(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + return bint_new(ix < iy and ix or iy) + end + return bint_parse(math_min(x, y)) + end + + --- Add an integer to a bint (in-place). + -- @param y An integer to be added. + -- @raise Asserts in case inputs are not convertible to integers. + function bint:_add(y) + y = bint_assert_convert(y) + local carry = 0 + for i = 1, BINT_SIZE do + local tmp = self[i] + y[i] + carry + carry = tmp >> BINT_WORDBITS + self[i] = tmp & BINT_WORDMAX + end + return self + end + + --- Add two numbers considering bints. + -- @param x A bint or a lua number to be added. + -- @param y A bint or a lua number to be added. + function bint.__add(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + local z = setmetatable({}, bint) + local carry = 0 + for i = 1, BINT_SIZE do + local tmp = ix[i] + iy[i] + carry + carry = tmp >> BINT_WORDBITS + z[i] = tmp & BINT_WORDMAX + end + return z + end + return bint_tonumber(x) + bint_tonumber(y) + end + + --- Subtract an integer from a bint (in-place). + -- @param y An integer to subtract. + -- @raise Asserts in case inputs are not convertible to integers. + function bint:_sub(y) + y = bint_assert_convert(y) + local borrow = 0 + local wordmaxp1 = BINT_WORDMAX + 1 + for i = 1, BINT_SIZE do + local res = self[i] + wordmaxp1 - y[i] - borrow + self[i] = res & BINT_WORDMAX + borrow = (res >> BINT_WORDBITS) ~ 1 + end + return self + end + + --- Subtract two numbers considering bints. + -- @param x A bint or a lua number to be subtracted from. + -- @param y A bint or a lua number to subtract. + function bint.__sub(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + local z = setmetatable({}, bint) + local borrow = 0 + local wordmaxp1 = BINT_WORDMAX + 1 + for i = 1, BINT_SIZE do + local res = ix[i] + wordmaxp1 - iy[i] - borrow + z[i] = res & BINT_WORDMAX + borrow = (res >> BINT_WORDBITS) ~ 1 + end + return z + end + return bint_tonumber(x) - bint_tonumber(y) + end + + --- Multiply two numbers considering bints. + -- @param x A bint or a lua number to multiply. + -- @param y A bint or a lua number to multiply. + function bint.__mul(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + local z = bint_zero() + local sizep1 = BINT_SIZE + 1 + local s = sizep1 + local e = 0 + for i = 1, BINT_SIZE do + if ix[i] ~= 0 or iy[i] ~= 0 then + e = math_max(e, i) + s = math_min(s, i) + end + end + for i = s, e do + for j = s, math_min(sizep1 - i, e) do + local a = ix[i] * iy[j] + if a ~= 0 then + local carry = 0 + for k = i + j - 1, BINT_SIZE do + local tmp = z[k] + (a & BINT_WORDMAX) + carry + carry = tmp >> BINT_WORDBITS + z[k] = tmp & BINT_WORDMAX + a = a >> BINT_WORDBITS + end + end + end + end + return z + end + return bint_tonumber(x) * bint_tonumber(y) + end + + --- Check if bints are equal. + -- @param x A bint to compare. + -- @param y A bint to compare. + function bint.__eq(x, y) + for i = 1, BINT_SIZE do + if x[i] ~= y[i] then + return false + end + end + return true + end + + --- Check if numbers are equal considering bints. + -- @param x A bint or lua number to compare. + -- @param y A bint or lua number to compare. + function bint.eq(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + return ix == iy + end + return x == y + end + local bint_eq = bint.eq + + local function findleftbit(x) + for i = BINT_SIZE, 1, -1 do + local v = x[i] + if v ~= 0 then + local j = 0 + repeat + v = v >> 1 + j = j + 1 + until v == 0 + return (i - 1) * BINT_WORDBITS + j - 1, i + end + end + end + + -- Single word division modulus + local function sudivmod(nume, deno) + local rema + local carry = 0 + for i = BINT_SIZE, 1, -1 do + carry = carry | nume[i] + nume[i] = carry // deno + rema = carry % deno + carry = rema << BINT_WORDBITS + end + return rema + end + + --- Perform unsigned division and modulo operation between two integers considering bints. + -- This is effectively the same of @{bint.udiv} and @{bint.umod}. + -- @param x The numerator, must be a bint or a lua integer. + -- @param y The denominator, must be a bint or a lua integer. + -- @return The quotient following the remainder, both bints. + -- @raise Asserts on attempt to divide by zero + -- or if inputs are not convertible to integers. + -- @see bint.udiv + -- @see bint.umod + function bint.udivmod(x, y) + local nume = bint_new(x) + local deno = bint_assert_convert(y) + -- compute if high bits of denominator are all zeros + local ishighzero = true + for i = 2, BINT_SIZE do + if deno[i] ~= 0 then + ishighzero = false + break + end + end + if ishighzero then + -- try to divide by a single word (optimization) + local low = deno[1] + assert(low ~= 0, 'attempt to divide by zero') + if low == 1 then + -- denominator is one + return nume, bint_zero() + elseif low <= (BINT_WORDMSB - 1) then + -- can do single word division + local rema = sudivmod(nume, low) + return nume, bint_fromuinteger(rema) + end + end + if nume:ult(deno) then + -- denominator is greater than numerator + return bint_zero(), nume + end + -- align leftmost digits in numerator and denominator + local denolbit = findleftbit(deno) + local numelbit, numesize = findleftbit(nume) + local bit = numelbit - denolbit + deno = deno << bit + local wordmaxp1 = BINT_WORDMAX + 1 + local wordbitsm1 = BINT_WORDBITS - 1 + local denosize = numesize + local quot = bint_zero() + while bit >= 0 do + -- compute denominator <= numerator + local le = true + local size = math_max(numesize, denosize) + for i = size, 1, -1 do + local a, b = deno[i], nume[i] + if a ~= b then + le = a < b + break + end + end + -- if the portion of the numerator above the denominator is greater or equal than to the denominator + if le then + -- subtract denominator from the portion of the numerator + local borrow = 0 + for i = 1, size do + local res = nume[i] + wordmaxp1 - deno[i] - borrow + nume[i] = res & BINT_WORDMAX + borrow = (res >> BINT_WORDBITS) ~ 1 + end + -- concatenate 1 to the right bit of the quotient + local i = (bit // BINT_WORDBITS) + 1 + quot[i] = quot[i] | (1 << (bit % BINT_WORDBITS)) + end + -- shift right the denominator in one bit + for i = 1, denosize - 1 do + deno[i] = ((deno[i] >> 1) | (deno[i + 1] << wordbitsm1)) & BINT_WORDMAX + end + local lastdenoword = deno[denosize] >> 1 + deno[denosize] = lastdenoword + -- recalculate denominator size (optimization) + if lastdenoword == 0 then + while deno[denosize] == 0 do + denosize = denosize - 1 + end + if denosize == 0 then + break + end + end + -- decrement current set bit for the quotient + bit = bit - 1 + end + -- the remaining numerator is the remainder + return quot, nume + end + local bint_udivmod = bint.udivmod + + --- Perform unsigned division between two integers considering bints. + -- @param x The numerator, must be a bint or a lua integer. + -- @param y The denominator, must be a bint or a lua integer. + -- @return The quotient, a bint. + -- @raise Asserts on attempt to divide by zero + -- or if inputs are not convertible to integers. + function bint.udiv(x, y) + return (bint_udivmod(x, y)) + end + + --- Perform unsigned integer modulo operation between two integers considering bints. + -- @param x The numerator, must be a bint or a lua integer. + -- @param y The denominator, must be a bint or a lua integer. + -- @return The remainder, a bint. + -- @raise Asserts on attempt to divide by zero + -- or if the inputs are not convertible to integers. + function bint.umod(x, y) + local _, rema = bint_udivmod(x, y) + return rema + end + local bint_umod = bint.umod + + --- Perform integer truncate division and modulo operation between two numbers considering bints. + -- This is effectively the same of @{bint.tdiv} and @{bint.tmod}. + -- @param x The numerator, a bint or lua number. + -- @param y The denominator, a bint or lua number. + -- @return The quotient following the remainder, both bint or lua number. + -- @raise Asserts on attempt to divide by zero or on division overflow. + -- @see bint.tdiv + -- @see bint.tmod + function bint.tdivmod(x, y) + local ax, ay = bint_abs(x), bint_abs(y) + local ix, iy = tobint(ax), tobint(ay) + local quot, rema + if ix and iy then + assert(not (bint_eq(x, BINT_MININTEGER) and bint_isminusone(y)), 'division overflow') + quot, rema = bint_udivmod(ix, iy) + else + quot, rema = ax // ay, ax % ay + end + local isxneg, isyneg = bint_isneg(x), bint_isneg(y) + if isxneg ~= isyneg then + quot = -quot + end + if isxneg then + rema = -rema + end + return quot, rema + end + local bint_tdivmod = bint.tdivmod + + --- Perform truncate division between two numbers considering bints. + -- Truncate division is a division that rounds the quotient towards zero. + -- @param x The numerator, a bint or lua number. + -- @param y The denominator, a bint or lua number. + -- @return The quotient, a bint or lua number. + -- @raise Asserts on attempt to divide by zero or on division overflow. + function bint.tdiv(x, y) + return (bint_tdivmod(x, y)) + end + + --- Perform integer truncate modulo operation between two numbers considering bints. + -- The operation is defined as the remainder of the truncate division + -- (division that rounds the quotient towards zero). + -- @param x The numerator, a bint or lua number. + -- @param y The denominator, a bint or lua number. + -- @return The remainder, a bint or lua number. + -- @raise Asserts on attempt to divide by zero or on division overflow. + function bint.tmod(x, y) + local _, rema = bint_tdivmod(x, y) + return rema + end + + --- Perform integer floor division and modulo operation between two numbers considering bints. + -- This is effectively the same of @{bint.__idiv} and @{bint.__mod}. + -- @param x The numerator, a bint or lua number. + -- @param y The denominator, a bint or lua number. + -- @return The quotient following the remainder, both bint or lua number. + -- @raise Asserts on attempt to divide by zero. + -- @see bint.__idiv + -- @see bint.__mod + function bint.idivmod(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + local isnumeneg = ix[BINT_SIZE] & BINT_WORDMSB ~= 0 + local isdenoneg = iy[BINT_SIZE] & BINT_WORDMSB ~= 0 + if isnumeneg then + ix = -ix + end + if isdenoneg then + iy = -iy + end + local quot, rema = bint_udivmod(ix, iy) + if isnumeneg ~= isdenoneg then + quot:_unm() + -- round quotient towards minus infinity + if not rema:iszero() then + quot:_dec() + -- adjust the remainder + if isnumeneg and not isdenoneg then + rema:_unm():_add(y) + elseif isdenoneg and not isnumeneg then + rema:_add(y) + end + end + elseif isnumeneg then + -- adjust the remainder + rema:_unm() + end + return quot, rema + end + local nx, ny = bint_tonumber(x), bint_tonumber(y) + return nx // ny, nx % ny + end + local bint_idivmod = bint.idivmod + + --- Perform floor division between two numbers considering bints. + -- Floor division is a division that rounds the quotient towards minus infinity, + -- resulting in the floor of the division of its operands. + -- @param x The numerator, a bint or lua number. + -- @param y The denominator, a bint or lua number. + -- @return The quotient, a bint or lua number. + -- @raise Asserts on attempt to divide by zero. + function bint.__idiv(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + local isnumeneg = ix[BINT_SIZE] & BINT_WORDMSB ~= 0 + local isdenoneg = iy[BINT_SIZE] & BINT_WORDMSB ~= 0 + if isnumeneg then + ix = -ix + end + if isdenoneg then + iy = -iy + end + local quot, rema = bint_udivmod(ix, iy) + if isnumeneg ~= isdenoneg then + quot:_unm() + -- round quotient towards minus infinity + if not rema:iszero() then + quot:_dec() + end + end + return quot, rema + end + return bint_tonumber(x) // bint_tonumber(y) + end + + --- Perform division between two numbers considering bints. + -- This always casts inputs to floats, for integer division only use @{bint.__idiv}. + -- @param x The numerator, a bint or lua number. + -- @param y The denominator, a bint or lua number. + -- @return The quotient, a lua number. + function bint.__div(x, y) + return bint_tonumber(x) / bint_tonumber(y) + end + + --- Perform integer floor modulo operation between two numbers considering bints. + -- The operation is defined as the remainder of the floor division + -- (division that rounds the quotient towards minus infinity). + -- @param x The numerator, a bint or lua number. + -- @param y The denominator, a bint or lua number. + -- @return The remainder, a bint or lua number. + -- @raise Asserts on attempt to divide by zero. + function bint.__mod(x, y) + local _, rema = bint_idivmod(x, y) + return rema + end + + --- Perform integer power between two integers considering bints. + -- If y is negative then pow is performed as an unsigned integer. + -- @param x The base, an integer. + -- @param y The exponent, an integer. + -- @return The result of the pow operation, a bint. + -- @raise Asserts in case inputs are not convertible to integers. + -- @see bint.__pow + -- @see bint.upowmod + function bint.ipow(x, y) + y = bint_assert_convert(y) + if y:iszero() then + return bint_one() + elseif y:isone() then + return bint_new(x) + end + -- compute exponentiation by squaring + x, y = bint_new(x), bint_new(y) + local z = bint_one() + repeat + if y:iseven() then + x = x * x + y:_shrone() + else + z = x * z + x = x * x + y:_dec():_shrone() + end + until y:isone() + return x * z + end + + --- Perform integer power between two unsigned integers over a modulus considering bints. + -- @param x The base, an integer. + -- @param y The exponent, an integer. + -- @param m The modulus, an integer. + -- @return The result of the pow operation, a bint. + -- @raise Asserts in case inputs are not convertible to integers. + -- @see bint.__pow + -- @see bint.ipow + function bint.upowmod(x, y, m) + m = bint_assert_convert(m) + if m:isone() then + return bint_zero() + end + x, y = bint_new(x), bint_new(y) + local z = bint_one() + x = bint_umod(x, m) + while not y:iszero() do + if y:isodd() then + z = bint_umod(z * x, m) + end + y:_shrone() + x = bint_umod(x * x, m) + end + return z + end + + --- Perform numeric power between two numbers considering bints. + -- This always casts inputs to floats, for integer power only use @{bint.ipow}. + -- @param x The base, a bint or lua number. + -- @param y The exponent, a bint or lua number. + -- @return The result of the pow operation, a lua number. + -- @see bint.ipow + function bint.__pow(x, y) + return bint_tonumber(x) ^ bint_tonumber(y) + end + + --- Bitwise left shift integers considering bints. + -- @param x An integer to perform the bitwise shift. + -- @param y An integer with the number of bits to shift. + -- @return The result of shift operation, a bint. + -- @raise Asserts in case inputs are not convertible to integers. + function bint.__shl(x, y) + x, y = bint_new(x), bint_assert_tointeger(y) + if y < 0 then + return x >> -y + end + local nvals = y // BINT_WORDBITS + if nvals ~= 0 then + x:_shlwords(nvals) + y = y - nvals * BINT_WORDBITS + end + if y ~= 0 then + local wordbitsmy = BINT_WORDBITS - y + for i = BINT_SIZE, 2, -1 do + x[i] = ((x[i] << y) | (x[i - 1] >> wordbitsmy)) & BINT_WORDMAX + end + x[1] = (x[1] << y) & BINT_WORDMAX + end + return x + end + + --- Bitwise right shift integers considering bints. + -- @param x An integer to perform the bitwise shift. + -- @param y An integer with the number of bits to shift. + -- @return The result of shift operation, a bint. + -- @raise Asserts in case inputs are not convertible to integers. + function bint.__shr(x, y) + x, y = bint_new(x), bint_assert_tointeger(y) + if y < 0 then + return x << -y + end + local nvals = y // BINT_WORDBITS + if nvals ~= 0 then + x:_shrwords(nvals) + y = y - nvals * BINT_WORDBITS + end + if y ~= 0 then + local wordbitsmy = BINT_WORDBITS - y + for i = 1, BINT_SIZE - 1 do + x[i] = ((x[i] >> y) | (x[i + 1] << wordbitsmy)) & BINT_WORDMAX + end + x[BINT_SIZE] = x[BINT_SIZE] >> y + end + return x + end + + --- Bitwise AND bints (in-place). + -- @param y An integer to perform bitwise AND. + -- @raise Asserts in case inputs are not convertible to integers. + function bint:_band(y) + y = bint_assert_convert(y) + for i = 1, BINT_SIZE do + self[i] = self[i] & y[i] + end + return self + end + + --- Bitwise AND two integers considering bints. + -- @param x An integer to perform bitwise AND. + -- @param y An integer to perform bitwise AND. + -- @raise Asserts in case inputs are not convertible to integers. + function bint.__band(x, y) + return bint_new(x):_band(y) + end + + --- Bitwise OR bints (in-place). + -- @param y An integer to perform bitwise OR. + -- @raise Asserts in case inputs are not convertible to integers. + function bint:_bor(y) + y = bint_assert_convert(y) + for i = 1, BINT_SIZE do + self[i] = self[i] | y[i] + end + return self + end + + --- Bitwise OR two integers considering bints. + -- @param x An integer to perform bitwise OR. + -- @param y An integer to perform bitwise OR. + -- @raise Asserts in case inputs are not convertible to integers. + function bint.__bor(x, y) + return bint_new(x):_bor(y) + end + + --- Bitwise XOR bints (in-place). + -- @param y An integer to perform bitwise XOR. + -- @raise Asserts in case inputs are not convertible to integers. + function bint:_bxor(y) + y = bint_assert_convert(y) + for i = 1, BINT_SIZE do + self[i] = self[i] ~ y[i] + end + return self + end + + --- Bitwise XOR two integers considering bints. + -- @param x An integer to perform bitwise XOR. + -- @param y An integer to perform bitwise XOR. + -- @raise Asserts in case inputs are not convertible to integers. + function bint.__bxor(x, y) + return bint_new(x):_bxor(y) + end + + --- Bitwise NOT a bint (in-place). + function bint:_bnot() + for i = 1, BINT_SIZE do + self[i] = (~self[i]) & BINT_WORDMAX + end + return self + end + + --- Bitwise NOT a bint. + -- @param x An integer to perform bitwise NOT. + -- @raise Asserts in case inputs are not convertible to integers. + function bint.__bnot(x) + local y = setmetatable({}, bint) + for i = 1, BINT_SIZE do + y[i] = (~x[i]) & BINT_WORDMAX + end + return y + end + + --- Negate a bint (in-place). This effectively applies two's complements. + function bint:_unm() + return self:_bnot():_inc() + end + + --- Negate a bint. This effectively applies two's complements. + -- @param x A bint to perform negation. + function bint.__unm(x) + return (~x):_inc() + end + + --- Compare if integer x is less than y considering bints (unsigned version). + -- @param x Left integer to compare. + -- @param y Right integer to compare. + -- @raise Asserts in case inputs are not convertible to integers. + -- @see bint.__lt + function bint.ult(x, y) + x, y = bint_assert_convert(x), bint_assert_convert(y) + for i = BINT_SIZE, 1, -1 do + local a, b = x[i], y[i] + if a ~= b then + return a < b + end + end + return false + end + + --- Compare if bint x is less or equal than y considering bints (unsigned version). + -- @param x Left integer to compare. + -- @param y Right integer to compare. + -- @raise Asserts in case inputs are not convertible to integers. + -- @see bint.__le + function bint.ule(x, y) + x, y = bint_assert_convert(x), bint_assert_convert(y) + for i = BINT_SIZE, 1, -1 do + local a, b = x[i], y[i] + if a ~= b then + return a < b + end + end + return true + end + + --- Compare if number x is less than y considering bints and signs. + -- @param x Left value to compare, a bint or lua number. + -- @param y Right value to compare, a bint or lua number. + -- @see bint.ult + function bint.__lt(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + local xneg = ix[BINT_SIZE] & BINT_WORDMSB ~= 0 + local yneg = iy[BINT_SIZE] & BINT_WORDMSB ~= 0 + if xneg == yneg then + for i = BINT_SIZE, 1, -1 do + local a, b = ix[i], iy[i] + if a ~= b then + return a < b + end + end + return false + end + return xneg and not yneg + end + return bint_tonumber(x) < bint_tonumber(y) + end + + --- Compare if number x is less or equal than y considering bints and signs. + -- @param x Left value to compare, a bint or lua number. + -- @param y Right value to compare, a bint or lua number. + -- @see bint.ule + function bint.__le(x, y) + local ix, iy = tobint(x), tobint(y) + if ix and iy then + local xneg = ix[BINT_SIZE] & BINT_WORDMSB ~= 0 + local yneg = iy[BINT_SIZE] & BINT_WORDMSB ~= 0 + if xneg == yneg then + for i = BINT_SIZE, 1, -1 do + local a, b = ix[i], iy[i] + if a ~= b then + return a < b + end + end + return true + end + return xneg and not yneg + end + return bint_tonumber(x) <= bint_tonumber(y) + end + + --- Convert a bint to a string on base 10. + -- @see bint.tobase + function bint:__tostring() + return self:tobase(10) + end + + -- Allow creating bints by calling bint itself + setmetatable(bint, { + __call = function(_, x) + return bint_new(x) + end, + }) + + BINT_MATHMININTEGER, BINT_MATHMAXINTEGER = bint_new(math.mininteger), bint_new(math.maxinteger) + BINT_MININTEGER = bint.mininteger() + memo[memoindex] = bint + + return bint + +end + +return newmodule diff --git a/framework/lualib/thirdparty/lester/lester.lua b/framework/lualib/thirdparty/lester/lester.lua new file mode 100644 index 0000000..64df041 --- /dev/null +++ b/framework/lualib/thirdparty/lester/lester.lua @@ -0,0 +1,473 @@ +--[[ +Minimal test framework for Lua. +lester - v0.1.2 - 15/Feb/2021 +Eduardo Bart - edub4rt@gmail.com +https://github.com/edubart/lester +Minimal Lua test framework. +See end of file for LICENSE. +]] --[[-- +Lester is a minimal unit testing framework for Lua with a focus on being simple to use. + +## Features + +* Minimal, just one file. +* Self contained, no external dependencies. +* Simple and hackable when needed. +* Use `describe` and `it` blocks to describe tests. +* Supports `before` and `after` handlers. +* Colored output. +* Configurable via the script or with environment variables. +* Quiet mode, to use in live development. +* Optionally filter tests by name. +* Show traceback on errors. +* Show time to complete tests. +* Works with Lua 5.1+. +* Efficient. + +## Usage + +Copy `lester.lua` file to a project and require it, +which returns a table that includes all of the functionality: + +```lua +local lester = require 'lester' +local describe, it, expect = lester.describe, lester.it, lester.expect + +-- Customize lester configuration. +lester.show_traceback = false + +describe('my project', function() + lester.before(function() + -- This function is run before every test. + end) + + describe('module1', function() -- Describe blocks can be nested. + it('feature1', function() + expect.equal('something', 'something') -- Pass. + end) + + it('feature2', function() + expect.truthy(false) -- Fail. + end) + end) +end) + +lester.report() -- Print overall statistic of the tests run. +lester.exit() -- Exit with success if all tests passed. +``` + +## Customizing output with environment variables + +To customize the output of lester externally, +you can set the following environment variables before running a test suite: + +* `LESTER_QUIET="true"`, omit print of passed tests. +* `LESTER_COLORED="false"`, disable colored output. +* `LESTER_SHOW_TRACEBACK="false"`, disable traceback on test failures. +* `LESTER_SHOW_ERROR="false"`, omit print of error description of failed tests. +* `LESTER_STOP_ON_FAIL="true"`, stop on first test failure. +* `LESTER_UTF8TERM="false"`, disable printing of UTF-8 characters. +* `LESTER_FILTER="some text"`, filter the tests that should be run. + +Note that these configurations can be changed via script too, check the documentation. + +]] -- Returns whether the terminal supports UTF-8 characters. +local function is_utf8term() + local lang = os.getenv('LANG') + return (lang and lang:lower():match('utf%-8$')) and true or false +end + +-- Returns whether a system environment variable is "true". +local function getboolenv(varname, default) + local val = os.getenv(varname) + if val == 'true' then + return true + elseif val == 'false' then + return false + end + return default +end + +-- The lester module. +local lester = { + --- Weather lines of passed tests should not be printed. False by default. + quiet = getboolenv('LESTER_QUIET', false), + --- Weather the output should be colorized. True by default. + colored = getboolenv('LESTER_COLORED', true), + --- Weather a traceback must be shown on test failures. True by default. + show_traceback = getboolenv('LESTER_SHOW_TRACEBACK', true), + --- Weather the error description of a test failure should be shown. True by default. + show_error = getboolenv('LESTER_SHOW_ERROR', true), + --- Weather test suite should exit on first test failure. False by default. + stop_on_fail = getboolenv('LESTER_STOP_ON_FAIL', false), + --- Weather we can print UTF-8 characters to the terminal. True by default when supported. + utf8term = getboolenv('LESTER_UTF8TERM', is_utf8term()), + --- A string with a lua pattern to filter tests. Nil by default. + filter = os.getenv('LESTER_FILTER'), + --- Function to retrieve time in seconds with milliseconds precision, `os.clock` by default. + seconds = os.clock, +} + +-- Variables used internally for the lester state. +local lester_start = nil +local last_succeeded = false +local level = 0 +local successes = 0 +local total_successes = 0 +local failures = 0 +local total_failures = 0 +local start = 0 +local befores = {} +local afters = {} +local names = {} + +-- Color codes. +local color_codes = { + reset = string.char(27) .. '[0m', + bright = string.char(27) .. '[1m', + red = string.char(27) .. '[31m', + green = string.char(27) .. '[32m', + blue = string.char(27) .. '[34m', + magenta = string.char(27) .. '[35m', +} + +-- Colors table, returning proper color code if colored mode is enabled. +local colors = setmetatable({}, { + __index = function(_, key) + return lester.colored and color_codes[key] or '' + end, +}) + +--- Table of terminal colors codes, can be customized. +lester.colors = colors + +--- Describe a block of tests, which consists in a set of tests. +-- Describes can be nested. +-- @param name A string used to describe the block. +-- @param func A function containing all the tests or other describes. +function lester.describe(name, func) + if level == 0 then -- Get start time for top level describe blocks. + start = lester.seconds() + if not lester_start then + lester_start = start + end + end + -- Setup describe block variables. + failures = 0 + successes = 0 + level = level + 1 + names[level] = name + -- Run the describe block. + func() + -- Cleanup describe block. + afters[level] = nil + befores[level] = nil + names[level] = nil + level = level - 1 + -- Pretty print statistics for top level describe block. + if level == 0 and not lester.quiet and (successes > 0 or failures > 0) then + local io_write = io.write + local colors_reset, colors_green = colors.reset, colors.green + io_write(failures == 0 and colors_green or colors.red, '[====] ', colors.magenta, name, colors_reset, ' | ', + colors_green, successes, colors_reset, ' successes / ') + if failures > 0 then + io_write(colors.red, failures, colors_reset, ' failures / ') + end + io_write(colors.bright, string.format('%.6f', lester.seconds() - start), colors_reset, ' seconds\n') + end +end + +-- Error handler used to get traceback for errors. +local function xpcall_error_handler(err) + return debug.traceback(tostring(err), 2) +end + +-- Pretty print the line on the test file where an error happened. +local function show_error_line(err) + local info = debug.getinfo(3) + local io_write = io.write + local colors_reset = colors.reset + local short_src, currentline = info.short_src, info.currentline + io_write(' (', colors.blue, short_src, colors_reset, ':', colors.bright, currentline, colors_reset) + if err and lester.show_traceback then + local fnsrc = short_src .. ':' .. currentline + for cap1, cap2 in err:gmatch('\t[^\n:]+:(%d+): in function <([^>]+)>\n') do + if cap2 == fnsrc then + io_write('/', colors.bright, cap1, colors_reset) + break + end + end + end + io_write(')') +end + +-- Pretty print the test name, with breadcrumb for the describe blocks. +local function show_test_name(name) + local io_write = io.write + local colors_reset = colors.reset + for _, descname in ipairs(names) do + io_write(colors.magenta, descname, colors_reset, ' | ') + end + io_write(colors.bright, name, colors_reset) +end + +--- Declare a test, which consists of a set of assertions. +-- @param name A name for the test. +-- @param func The function containing all assertions. +function lester.it(name, func) + -- Skip the test if it does not match the filter. + if lester.filter then + local fullname = table.concat(names, ' | ') .. ' | ' .. name + if not fullname:match(lester.filter) then + return + end + end + -- Execute before handlers. + for _, levelbefores in ipairs(befores) do + for _, beforefn in ipairs(levelbefores) do + beforefn(name) + end + end + -- Run the test, capturing errors if any. + local success, err + if lester.show_traceback then + success, err = xpcall(func, xpcall_error_handler) + else + success, err = pcall(func) + if not success and err then + err = tostring(err) + end + end + -- Count successes and failures. + if success then + successes = successes + 1 + total_successes = total_successes + 1 + else + failures = failures + 1 + total_failures = total_failures + 1 + end + local io_write = io.write + local colors_reset = colors.reset + -- Print the test run. + if not lester.quiet then -- Show test status and complete test name. + if success then + io_write(colors.green, '[PASS] ', colors_reset) + else + io_write(colors.red, '[FAIL] ', colors_reset) + end + show_test_name(name) + if not success then + show_error_line(err) + end + io_write('\n') + else + if success then -- Show just a character hinting that the test succeeded. + local o = (lester.utf8term and lester.colored) and string.char(226, 151, 143) or 'o' + io_write(colors.green, o, colors_reset) + else -- Show complete test name on failure. + io_write(last_succeeded and '\n' or '', colors.red, '[FAIL] ', colors_reset) + show_test_name(name) + show_error_line(err) + io_write('\n') + end + end + -- Print error message, colorizing its output if possible. + if err and lester.show_error then + if lester.colored then + local errfile, errline, errmsg, rest = err:match('^([^:\n]+):(%d+): ([^\n]+)(.*)') + if errfile and errline and errmsg and rest then + io_write(colors.blue, errfile, colors_reset, ':', colors.bright, errline, colors_reset, ': ') + if errmsg:match('^%w([^:]*)$') then + io_write(colors.red, errmsg, colors_reset) + else + io_write(errmsg) + end + err = rest + end + end + io_write(err, '\n\n') + end + io.flush() + -- Stop on failure. + if not success and lester.stop_on_fail then + if lester.quiet then + io_write('\n') + io.flush() + end + lester.exit() + end + -- Execute after handlers. + for _, levelafters in ipairs(afters) do + for _, afterfn in ipairs(levelafters) do + afterfn(name) + end + end + last_succeeded = success +end + +--- Set a function that is called before every test inside a describe block. +-- A single string containing the name of the test about to be run will be passed to `func`. +function lester.before(func) + local levelbefores = befores[level] + if not levelbefores then + levelbefores = {} + befores[level] = levelbefores + end + levelbefores[#levelbefores + 1] = func +end + +--- Set a function that is called after every test inside a describe block. +-- A single string containing the name of the test that was finished will be passed to `func`. +-- The function is executed independently if the test passed or failed. +function lester.after(func) + local levelafters = afters[level] + if not levelafters then + levelafters = {} + afters[level] = levelafters + end + levelafters[#levelafters + 1] = func +end + +--- Pretty print statistics of all test runs. +-- With total success, total failures and run time in seconds. +function lester.report() + local now = lester.seconds() + local colors_reset = colors.reset + io.write(lester.quiet and '\n' or '', colors.green, total_successes, colors_reset, ' successes / ', colors.red, + total_failures, colors_reset, ' failures / ', colors.bright, string.format('%.6f', now - (lester_start or now)), + colors_reset, ' seconds\n') + io.flush() + return total_failures == 0 +end + +--- Exit the application with success code if all tests passed, or failure code otherwise. +function lester.exit() + os.exit(total_failures == 0) +end + +local expect = {} +--- Expect module, containing utility function for doing assertions inside a test. +lester.expect = expect + +--- Check if a function fails with an error. +-- If `expected` is nil then any error is accepted. +-- If `expected` is a string then we check if the error contains that string. +-- If `expected` is anything else then we check if both are equal. +function expect.fail(func, expected) + local ok, err = pcall(func) + if ok then + error('expected function to fail', 2) + elseif expected ~= nil then + local found = expected == err + if not found and type(expected) == 'string' then + found = string.find(tostring(err), expected, 1, true) + end + if not found then + error('expected function to fail\nexpected:\n' .. tostring(expected) .. '\ngot:\n' .. tostring(err), 2) + end + end +end + +--- Check if a function does not fail with a error. +function expect.not_fail(func) + local ok, err = pcall(func) + if not ok then + error('expected function to not fail\ngot error:\n' .. tostring(err), 2) + end +end + +--- Check if a value is not `nil`. +function expect.exist(v) + if v == nil then + error('expected value to exist\ngot:\n' .. tostring(v), 2) + end +end + +--- Check if a value is `nil`. +function expect.not_exist(v) + if v ~= nil then + error('expected value to not exist\ngot:\n' .. tostring(v), 2) + end +end + +--- Check if an expression is evaluates to `true`. +function expect.truthy(v) + if not v then + error('expected expression to be true\ngot:\n' .. tostring(v), 2) + end +end + +--- Check if an expression is evaluates to `false`. +function expect.falsy(v) + if v then + error('expected expression to be false\ngot:\n' .. tostring(v), 2) + end +end + +--- Compare if two values are equal, considering nested tables. +local function strict_eq(t1, t2) + if rawequal(t1, t2) then + return true + end + if type(t1) ~= type(t2) then + return false + end + if type(t1) ~= 'table' then + return t1 == t2 + end + if getmetatable(t1) ~= getmetatable(t2) then + return false + end + for k, v1 in pairs(t1) do + if not strict_eq(v1, t2[k]) then + return false + end + end + for k, v2 in pairs(t2) do + if not strict_eq(v2, t1[k]) then + return false + end + end + return true +end + +--- Check if two values are equal. +function expect.equal(v1, v2) + if not strict_eq(v1, v2) then + error('expected values to be equal\nfirst value:\n' .. tostring(v1) .. '\nsecond value:\n' .. tostring(v2), 2) + end +end + +--- Check if two values are not equal. +function expect.not_equal(v1, v2) + if strict_eq(v1, v2) then + error('expected values to be not equal\nfirst value:\n' .. tostring(v1) .. '\nsecond value:\n' .. tostring(v2), + 2) + end +end + +return lester + +--[[ + The MIT License (MIT) + + Copyright (c) 2021 Eduardo Bart (https://github.com/edubart) + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + ]] diff --git a/framework/lualib/thirdparty/luassert/array.lua b/framework/lualib/thirdparty/luassert/array.lua new file mode 100755 index 0000000..ad6c9fd --- /dev/null +++ b/framework/lualib/thirdparty/luassert/array.lua @@ -0,0 +1,70 @@ +local assert = require('luassert.assert') +local say = require('luassert.say') + +-- Example usage: +-- local arr = { "one", "two", "three" } +-- +-- assert.array(arr).has.no.holes() -- checks the array to not contain holes --> passes +-- assert.array(arr).has.no.holes(4) -- sets explicit length to 4 --> fails +-- +-- local first_hole = assert.array(arr).has.holes(4) -- check array of size 4 to contain holes --> passes +-- assert.equal(4, first_hole) -- passes, as the index of the first hole is returned + + +-- Unique key to store the object we operate on in the state object +-- key must be unique, to make sure we do not have name collissions in the shared state object +local ARRAY_STATE_KEY = "__array_state" + +-- The modifier, to store the object in our state +local function array(state, args, level) + assert(args.n > 0, "No array provided to the array-modifier") + assert(rawget(state, ARRAY_STATE_KEY) == nil, "Array already set") + rawset(state, ARRAY_STATE_KEY, args[1]) + return state +end + +-- The actual assertion that operates on our object, stored via the modifier +local function holes(state, args, level) + local length = args[1] + local arr = rawget(state, ARRAY_STATE_KEY) -- retrieve previously set object + -- only check against nil, metatable types are allowed + assert(arr ~= nil, "No array set, please use the array modifier to set the array to validate") + if length == nil then + length = 0 + for i in pairs(arr) do + if type(i) == "number" and + i > length and + math.floor(i) == i then + length = i + end + end + end + assert(type(length) == "number", "expected array length to be of type 'number', got: "..tostring(length)) + -- let's do the actual assertion + local missing + for i = 1, length do + if arr[i] == nil then + missing = i + break + end + end + -- format arguments for output strings; + args[1] = missing + args.n = missing and 1 or 0 + return missing ~= nil, { missing } -- assert result + first missing index as return value +end + +-- Register the proper assertion messages +say:set("assertion.array_holes.positive", [[ +Expected array to have holes, but none was found. +]]) +say:set("assertion.array_holes.negative", [[ +Expected array to not have holes, hole found at position: %s +]]) + +-- Register the assertion, and the modifier +assert:register("assertion", "holes", holes, + "assertion.array_holes.positive", + "assertion.array_holes.negative") + +assert:register("modifier", "array", array) diff --git a/framework/lualib/thirdparty/luassert/assert.lua b/framework/lualib/thirdparty/luassert/assert.lua new file mode 100755 index 0000000..ca53e40 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/assert.lua @@ -0,0 +1,180 @@ +local s = require 'luassert.say' +local astate = require 'luassert.state' +local util = require 'luassert.util' +local unpack = util.unpack +local obj -- the returned module table +local level_mt = {} + +-- list of namespaces +local namespace = require 'luassert.namespaces' + +local function geterror(assertion_message, failure_message, args) + if util.hastostring(failure_message) then + failure_message = tostring(failure_message) + elseif failure_message ~= nil then + failure_message = astate.format_argument(failure_message) + end + local message = s(assertion_message, obj:format(args)) + if message and failure_message then + message = failure_message .. "\n" .. message + end + return message or failure_message +end + +local __state_meta = { + + __call = function(self, ...) + local keys = util.extract_keys("assertion", self.tokens) + + local assertion + + for _, key in ipairs(keys) do + assertion = namespace.assertion[key] or assertion + end + + if assertion then + for _, key in ipairs(keys) do + if namespace.modifier[key] then + namespace.modifier[key].callback(self) + end + end + + local arguments = util.make_arglist(...) + local val, retargs = assertion.callback(self, arguments, util.errorlevel()) + + if not val == self.mod then + local message = assertion.positive_message + if not self.mod then + message = assertion.negative_message + end + local err = geterror(message, rawget(self,"failure_message"), arguments) + error(err or "assertion failed!", util.errorlevel()) + end + + if retargs then + return unpack(retargs) + end + return ... + else + local arguments = util.make_arglist(...) + self.tokens = {} + + for _, key in ipairs(keys) do + if namespace.modifier[key] then + namespace.modifier[key].callback(self, arguments, util.errorlevel()) + end + end + end + + return self + end, + + __index = function(self, key) + for token in key:lower():gmatch('[^_]+') do + table.insert(self.tokens, token) + end + + return self + end +} + +obj = { + state = function() return setmetatable({mod=true, tokens={}}, __state_meta) end, + + -- registers a function in namespace + register = function(self, nspace, name, callback, positive_message, negative_message) + local lowername = name:lower() + if not namespace[nspace] then + namespace[nspace] = {} + end + namespace[nspace][lowername] = { + callback = callback, + name = lowername, + positive_message=positive_message, + negative_message=negative_message + } + end, + + -- unregisters a function in a namespace + unregister = function(self, nspace, name) + local lowername = name:lower() + if not namespace[nspace] then + namespace[nspace] = {} + end + namespace[nspace][lowername] = nil + end, + + -- registers a formatter + -- a formatter takes a single argument, and converts it to a string, or returns nil if it cannot format the argument + add_formatter = function(self, callback) + astate.add_formatter(callback) + end, + + -- unregisters a formatter + remove_formatter = function(self, fmtr) + astate.remove_formatter(fmtr) + end, + + format = function(self, args) + -- args.n specifies the number of arguments in case of 'trailing nil' arguments which get lost + local nofmt = args.nofmt or {} -- arguments in this list should not be formatted + local fmtargs = args.fmtargs or {} -- additional arguments to be passed to formatter + for i = 1, (args.n or #args) do -- cannot use pairs because table might have nils + if not nofmt[i] then + local val = args[i] + local valfmt = astate.format_argument(val, nil, fmtargs[i]) + if valfmt == nil then valfmt = tostring(val) end -- no formatter found + args[i] = valfmt + end + end + return args + end, + + set_parameter = function(self, name, value) + astate.set_parameter(name, value) + end, + + get_parameter = function(self, name) + return astate.get_parameter(name) + end, + + add_spy = function(self, spy) + astate.add_spy(spy) + end, + + snapshot = function(self) + return astate.snapshot() + end, + + level = function(self, level) + return setmetatable({ + level = level + }, level_mt) + end, + + -- returns the level if a level-value, otherwise nil + get_level = function(self, level) + if getmetatable(level) ~= level_mt then + return nil -- not a valid error-level + end + return level.level + end, +} + +local __meta = { + + __call = function(self, bool, message, level, ...) + if not bool then + local err_level = (self:get_level(level) or 1) + 1 + error(message or "assertion failed!", err_level) + end + return bool , message , level , ... + end, + + __index = function(self, key) + return rawget(self, key) or self.state()[key] + end, + +} + +return setmetatable(obj, __meta) diff --git a/framework/lualib/thirdparty/luassert/assertions.lua b/framework/lualib/thirdparty/luassert/assertions.lua new file mode 100755 index 0000000..0b74834 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/assertions.lua @@ -0,0 +1,328 @@ +-- module will not return anything, only register assertions with the main assert engine + +-- assertions take 2 parameters; +-- 1) state +-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils +-- 3) level The level of the error position relative to the called function +-- returns; boolean; whether assertion passed + +local assert = require('luassert.assert') +local astate = require ('luassert.state') +local util = require ('luassert.util') +local s = require('luassert.say') + +local function format(val) + return astate.format_argument(val) or tostring(val) +end + +local function set_failure_message(state, message) + if message ~= nil then + state.failure_message = message + end +end + +local function unique(state, arguments, level) + local list = arguments[1] + local deep + local argcnt = arguments.n + if type(arguments[2]) == "boolean" or (arguments[2] == nil and argcnt > 2) then + deep = arguments[2] + set_failure_message(state, arguments[3]) + else + if type(arguments[3]) == "boolean" then + deep = arguments[3] + end + set_failure_message(state, arguments[2]) + end + for k,v in pairs(list) do + for k2, v2 in pairs(list) do + if k ~= k2 then + if deep and util.deepcompare(v, v2, true) then + return false + else + if v == v2 then + return false + end + end + end + end + end + return true +end + +local function near(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 2, s("assertion.internal.argtolittle", { "near", 3, tostring(argcnt) }), level) + local expected = tonumber(arguments[1]) + local actual = tonumber(arguments[2]) + local tolerance = tonumber(arguments[3]) + local numbertype = "number or object convertible to a number" + assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level) + assert(actual, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level) + assert(tolerance, s("assertion.internal.badargtype", { 3, "near", numbertype, format(arguments[3]) }), level) + -- switch arguments for proper output message + util.tinsert(arguments, 1, util.tremove(arguments, 2)) + arguments[3] = tolerance + arguments.nofmt = arguments.nofmt or {} + arguments.nofmt[3] = true + set_failure_message(state, arguments[4]) + return (actual >= expected - tolerance and actual <= expected + tolerance) +end + +local function matches(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 1, s("assertion.internal.argtolittle", { "matches", 2, tostring(argcnt) }), level) + local pattern = arguments[1] + local actual = nil + if util.hastostring(arguments[2]) or type(arguments[2]) == "number" then + actual = tostring(arguments[2]) + end + local err_message + local init_arg_num = 3 + for i=3,argcnt,1 do + if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then + if i == 3 then init_arg_num = init_arg_num + 1 end + err_message = util.tremove(arguments, i) + break + end + end + local init = arguments[3] + local plain = arguments[4] + local stringtype = "string or object convertible to a string" + assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level) + assert(actual, s("assertion.internal.badargtype", { 2, "matches", stringtype, format(arguments[2]) }), level) + assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level) + -- switch arguments for proper output message + util.tinsert(arguments, 1, util.tremove(arguments, 2)) + set_failure_message(state, err_message) + local retargs + local ok + if plain then + ok = (actual:find(pattern, init, plain) ~= nil) + retargs = ok and { pattern } or {} + else + retargs = { actual:match(pattern, init) } + ok = (retargs[1] ~= nil) + end + return ok, retargs +end + +local function equals(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 1, s("assertion.internal.argtolittle", { "equals", 2, tostring(argcnt) }), level) + local result = arguments[1] == arguments[2] + -- switch arguments for proper output message + util.tinsert(arguments, 1, util.tremove(arguments, 2)) + set_failure_message(state, arguments[3]) + return result +end + +local function same(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 1, s("assertion.internal.argtolittle", { "same", 2, tostring(argcnt) }), level) + if type(arguments[1]) == 'table' and type(arguments[2]) == 'table' then + local result, crumbs = util.deepcompare(arguments[1], arguments[2], true) + -- switch arguments for proper output message + util.tinsert(arguments, 1, util.tremove(arguments, 2)) + arguments.fmtargs = arguments.fmtargs or {} + arguments.fmtargs[1] = { crumbs = crumbs } + arguments.fmtargs[2] = { crumbs = crumbs } + set_failure_message(state, arguments[3]) + return result + end + local result = arguments[1] == arguments[2] + -- switch arguments for proper output message + util.tinsert(arguments, 1, util.tremove(arguments, 2)) + set_failure_message(state, arguments[3]) + return result +end + +local function truthy(state, arguments, level) + set_failure_message(state, arguments[2]) + return arguments[1] ~= false and arguments[1] ~= nil +end + +local function falsy(state, arguments, level) + return not truthy(state, arguments, level) +end + +local function has_error(state, arguments, level) + local level = (level or 1) + 1 + local retargs = util.shallowcopy(arguments) + local func = arguments[1] + local err_expected = arguments[2] + local failure_message = arguments[3] + assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error", "function or callable object", type(func) }), level) + local ok, err_actual = pcall(func) + if type(err_actual) == 'string' then + -- remove 'path/to/file:line: ' from string + err_actual = err_actual:gsub('^.-:%d+: ', '', 1) + end + retargs[1] = err_actual + arguments.nofmt = {} + arguments.n = 2 + arguments[1] = (ok and '(no error)' or err_actual) + arguments[2] = (err_expected == nil and '(error)' or err_expected) + arguments.nofmt[1] = ok + arguments.nofmt[2] = (err_expected == nil) + set_failure_message(state, failure_message) + + if ok or err_expected == nil then + return not ok, retargs + end + if type(err_expected) == 'string' then + -- err_actual must be (convertible to) a string + if util.hastostring(err_actual) then + err_actual = tostring(err_actual) + retargs[1] = err_actual + end + if type(err_actual) == 'string' then + return err_expected == err_actual, retargs + end + elseif type(err_expected) == 'number' then + if type(err_actual) == 'string' then + return tostring(err_expected) == tostring(tonumber(err_actual)), retargs + end + end + return same(state, {err_expected, err_actual, ["n"] = 2}), retargs +end + +local function error_matches(state, arguments, level) + local level = (level or 1) + 1 + local retargs = util.shallowcopy(arguments) + local argcnt = arguments.n + local func = arguments[1] + local pattern = arguments[2] + assert(argcnt > 1, s("assertion.internal.argtolittle", { "error_matches", 2, tostring(argcnt) }), level) + assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error_matches", "function or callable object", type(func) }), level) + assert(pattern == nil or type(pattern) == "string", s("assertion.internal.badargtype", { 2, "error", "string", type(pattern) }), level) + + local failure_message + local init_arg_num = 3 + for i=3,argcnt,1 do + if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then + if i == 3 then init_arg_num = init_arg_num + 1 end + failure_message = util.tremove(arguments, i) + break + end + end + local init = arguments[3] + local plain = arguments[4] + assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level) + + local ok, err_actual = pcall(func) + if type(err_actual) == 'string' then + -- remove 'path/to/file:line: ' from string + err_actual = err_actual:gsub('^.-:%d+: ', '', 1) + end + retargs[1] = err_actual + arguments.nofmt = {} + arguments.n = 2 + arguments[1] = (ok and '(no error)' or err_actual) + arguments[2] = pattern + arguments.nofmt[1] = ok + arguments.nofmt[2] = false + set_failure_message(state, failure_message) + + if ok then return not ok, retargs end + if err_actual == nil and pattern == nil then + return true, {} + end + + -- err_actual must be (convertible to) a string + if util.hastostring(err_actual) then + err_actual = tostring(err_actual) + retargs[1] = err_actual + end + if type(err_actual) == 'string' then + local ok + local retargs_ok + if plain then + retargs_ok = { pattern } + ok = (err_actual:find(pattern, init, plain) ~= nil) + else + retargs_ok = { err_actual:match(pattern, init) } + ok = (retargs_ok[1] ~= nil) + end + if ok then retargs = retargs_ok end + return ok, retargs + end + + return false, retargs +end + +local function is_true(state, arguments, level) + util.tinsert(arguments, 2, true) + set_failure_message(state, arguments[3]) + return arguments[1] == arguments[2] +end + +local function is_false(state, arguments, level) + util.tinsert(arguments, 2, false) + set_failure_message(state, arguments[3]) + return arguments[1] == arguments[2] +end + +local function is_type(state, arguments, level, etype) + util.tinsert(arguments, 2, "type " .. etype) + arguments.nofmt = arguments.nofmt or {} + arguments.nofmt[2] = true + set_failure_message(state, arguments[3]) + return arguments.n > 1 and type(arguments[1]) == etype +end + +local function returned_arguments(state, arguments, level) + arguments[1] = tostring(arguments[1]) + arguments[2] = tostring(arguments.n - 1) + arguments.nofmt = arguments.nofmt or {} + arguments.nofmt[1] = true + arguments.nofmt[2] = true + if arguments.n < 2 then arguments.n = 2 end + return arguments[1] == arguments[2] +end + +local function set_message(state, arguments, level) + state.failure_message = arguments[1] +end + +local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end +local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end +local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end +local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end +local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end +local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end +local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end +local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end + +assert:register("modifier", "message", set_message) +assert:register("assertion", "true", is_true, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "false", is_false, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "boolean", is_boolean, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "number", is_number, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "string", is_string, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "table", is_table, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "nil", is_nil, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "userdata", is_userdata, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "function", is_function, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "thread", is_thread, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "returned_arguments", returned_arguments, "assertion.returned_arguments.positive", "assertion.returned_arguments.negative") + +assert:register("assertion", "same", same, "assertion.same.positive", "assertion.same.negative") +assert:register("assertion", "matches", matches, "assertion.matches.positive", "assertion.matches.negative") +assert:register("assertion", "match", matches, "assertion.matches.positive", "assertion.matches.negative") +assert:register("assertion", "near", near, "assertion.near.positive", "assertion.near.negative") +assert:register("assertion", "equals", equals, "assertion.equals.positive", "assertion.equals.negative") +assert:register("assertion", "equal", equals, "assertion.equals.positive", "assertion.equals.negative") +assert:register("assertion", "unique", unique, "assertion.unique.positive", "assertion.unique.negative") +assert:register("assertion", "error", has_error, "assertion.error.positive", "assertion.error.negative") +assert:register("assertion", "errors", has_error, "assertion.error.positive", "assertion.error.negative") +assert:register("assertion", "error_matches", error_matches, "assertion.error.positive", "assertion.error.negative") +assert:register("assertion", "error_match", error_matches, "assertion.error.positive", "assertion.error.negative") +assert:register("assertion", "matches_error", error_matches, "assertion.error.positive", "assertion.error.negative") +assert:register("assertion", "match_error", error_matches, "assertion.error.positive", "assertion.error.negative") +assert:register("assertion", "truthy", truthy, "assertion.truthy.positive", "assertion.truthy.negative") +assert:register("assertion", "falsy", falsy, "assertion.falsy.positive", "assertion.falsy.negative") diff --git a/framework/lualib/thirdparty/luassert/compatibility.lua b/framework/lualib/thirdparty/luassert/compatibility.lua new file mode 100755 index 0000000..88290ad --- /dev/null +++ b/framework/lualib/thirdparty/luassert/compatibility.lua @@ -0,0 +1,9 @@ +-- no longer needed, only for backward compatibility +local unpack = require ("luassert.util").unpack + +return { + unpack = function(...) + print(debug.traceback("WARN: calling deprecated function 'luassert.compatibility.unpack' use 'luassert.util.unpack' instead")) + return unpack(...) + end +} diff --git a/framework/lualib/thirdparty/luassert/formatters/binarystring.lua b/framework/lualib/thirdparty/luassert/formatters/binarystring.lua new file mode 100755 index 0000000..a1c5c88 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/formatters/binarystring.lua @@ -0,0 +1,33 @@ +local format = function(str) + if type(str) ~= "string" then + return nil + end + local result = "Binary string length; " .. tostring(#str) .. " bytes\n" + local i = 1 + local hex = "" + local chr = "" + while i <= #str do + local byte = str:byte(i) + hex = string.format("%s%2x ", hex, byte) + if byte < 32 then + byte = string.byte(".") + end + chr = chr .. string.char(byte) + if math.floor(i / 16) == i / 16 or i == #str then + -- reached end of line + hex = hex .. string.rep(" ", 16 * 3 - #hex) + chr = chr .. string.rep(" ", 16 - #chr) + + result = result .. hex:sub(1, 8 * 3) .. " " .. hex:sub(8 * 3 + 1, -1) .. " " .. chr:sub(1, 8) .. " " .. + chr:sub(9, -1) .. "\n" + + hex = "" + chr = "" + end + i = i + 1 + end + return result +end + +return format + diff --git a/framework/lualib/thirdparty/luassert/formatters/init.lua b/framework/lualib/thirdparty/luassert/formatters/init.lua new file mode 100755 index 0000000..718a92e --- /dev/null +++ b/framework/lualib/thirdparty/luassert/formatters/init.lua @@ -0,0 +1,258 @@ +-- module will not return anything, only register formatters with the main assert engine +local assert = require('luassert.assert') +local match = require('luassert.match') +local util = require('luassert.util') + +local colors = setmetatable({ + none = function(c) + return c + end, +}, { + __index = function(self, key) + local ok, term = pcall(require, 'term') + local isatty = io.type(io.stdout) == 'file' and ok and term.isatty(io.stdout) + if not ok or not isatty or not term.colors then + return function(c) + return c + end + end + return function(c) + for token in key:gmatch("[^%.]+") do + c = term.colors[token](c) + end + return c + end + end, +}) + +local function fmt_string(arg) + if type(arg) == "string" then + return string.format("(string) '%s'", arg) + end +end + +-- A version of tostring which formats numbers more precisely. +local function tostr(arg) + if type(arg) ~= "number" then + return tostring(arg) + end + + if arg ~= arg then + return "NaN" + elseif arg == 1 / 0 then + return "Inf" + elseif arg == -1 / 0 then + return "-Inf" + end + + local str = string.format("%.20g", arg) + + if math.type and math.type(arg) == "float" and not str:find("[%.,]") then + -- Number is a float but looks like an integer. + -- Insert ".0" after first run of digits. + str = str:gsub("%d+", "%0.0", 1) + end + + return str +end + +local function fmt_number(arg) + if type(arg) == "number" then + return string.format("(number) %s", tostr(arg)) + end +end + +local function fmt_boolean(arg) + if type(arg) == "boolean" then + return string.format("(boolean) %s", tostring(arg)) + end +end + +local function fmt_nil(arg) + if type(arg) == "nil" then + return "(nil)" + end +end + +local type_priorities = { + number = 1, + boolean = 2, + string = 3, + table = 4, + ["function"] = 5, + userdata = 6, + thread = 7, +} + +local function is_in_array_part(key, length) + return type(key) == "number" and 1 <= key and key <= length and math.floor(key) == key +end + +local function get_sorted_keys(t) + local keys = {} + local nkeys = 0 + + for key in pairs(t) do + nkeys = nkeys + 1 + keys[nkeys] = key + end + + local length = #t + + local function key_comparator(key1, key2) + local type1, type2 = type(key1), type(key2) + local priority1 = is_in_array_part(key1, length) and 0 or type_priorities[type1] or 8 + local priority2 = is_in_array_part(key2, length) and 0 or type_priorities[type2] or 8 + + if priority1 == priority2 then + if type1 == "string" or type1 == "number" then + return key1 < key2 + elseif type1 == "boolean" then + return key1 -- put true before false + end + else + return priority1 < priority2 + end + end + + table.sort(keys, key_comparator) + return keys, nkeys +end + +local function fmt_table(arg, fmtargs) + if type(arg) ~= "table" then + return + end + + local tmax = assert:get_parameter("TableFormatLevel") + local showrec = assert:get_parameter("TableFormatShowRecursion") + local errchar = assert:get_parameter("TableErrorHighlightCharacter") or "" + local errcolor = assert:get_parameter("TableErrorHighlightColor") or "none" + local crumbs = fmtargs and fmtargs.crumbs or {} + local cache = {} + local type_desc + + if getmetatable(arg) == nil then + type_desc = "(" .. tostring(arg) .. ") " + elseif not pcall(setmetatable, arg, getmetatable(arg)) then + -- cannot set same metatable, so it is protected, skip id + type_desc = "(table) " + else + -- unprotected metatable, temporary remove the mt + local mt = getmetatable(arg) + setmetatable(arg, nil) + type_desc = "(" .. tostring(arg) .. ") " + setmetatable(arg, mt) + end + + local function ft(t, l, with_crumbs) + if showrec and cache[t] and cache[t] > 0 then + return "{ ... recursive }" + end + + if next(t) == nil then + return "{ }" + end + + if l > tmax and tmax >= 0 then + return "{ ... more }" + end + + local result = "{" + local keys, nkeys = get_sorted_keys(t) + + cache[t] = (cache[t] or 0) + 1 + local crumb = crumbs[#crumbs - l + 1] + + for i = 1, nkeys do + local k = keys[i] + local v = t[k] + local use_crumbs = with_crumbs and k == crumb + + if type(v) == "table" then + v = ft(v, l + 1, use_crumbs) + elseif type(v) == "string" then + v = "'" .. v .. "'" + end + + local ch = use_crumbs and errchar or "" + local indent = string.rep(" ", l * 2 - ch:len()) + local mark = (ch:len() == 0 and "" or colors[errcolor](ch)) + result = result .. string.format("\n%s%s[%s] = %s", indent, mark, tostr(k), tostr(v)) + end + + cache[t] = cache[t] - 1 + + return result .. " }" + end + + return type_desc .. ft(arg, 1, true) +end + +local function fmt_function(arg) + if type(arg) == "function" then + local debug_info = debug.getinfo(arg) + return string.format("%s @ line %s in %s", tostring(arg), tostring(debug_info.linedefined), + tostring(debug_info.source)) + end +end + +local function fmt_userdata(arg) + if type(arg) == "userdata" then + return string.format("(userdata) '%s'", tostring(arg)) + end +end + +local function fmt_thread(arg) + if type(arg) == "thread" then + return string.format("(thread) '%s'", tostring(arg)) + end +end + +local function fmt_matcher(arg) + if not match.is_matcher(arg) then + return + end + local not_inverted = { + [true] = "is.", + [false] = "no.", + } + local args = {} + for idx = 1, arg.arguments.n do + table.insert(args, assert:format({ + arg.arguments[idx], + n = 1, + })[1]) + end + return string.format("(matcher) %s%s(%s)", not_inverted[arg.mod], tostring(arg.name), table.concat(args, ", ")) +end + +local function fmt_arglist(arglist) + if not util.is_arglist(arglist) then + return + end + local formatted_vals = {} + for idx = 1, arglist.n do + table.insert(formatted_vals, assert:format({ + arglist[idx], + n = 1, + })[1]) + end + return "(values list) (" .. table.concat(formatted_vals, ", ") .. ")" +end + +assert:add_formatter(fmt_string) +assert:add_formatter(fmt_number) +assert:add_formatter(fmt_boolean) +assert:add_formatter(fmt_nil) +assert:add_formatter(fmt_table) +assert:add_formatter(fmt_function) +assert:add_formatter(fmt_userdata) +assert:add_formatter(fmt_thread) +assert:add_formatter(fmt_matcher) +assert:add_formatter(fmt_arglist) +-- Set default table display depth for table formatter +assert:set_parameter("TableFormatLevel", 3) +assert:set_parameter("TableFormatShowRecursion", false) +assert:set_parameter("TableErrorHighlightCharacter", "*") +assert:set_parameter("TableErrorHighlightColor", "none") diff --git a/framework/lualib/thirdparty/luassert/init.lua b/framework/lualib/thirdparty/luassert/init.lua new file mode 100755 index 0000000..a181942 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/init.lua @@ -0,0 +1,18 @@ +local assert = require('luassert.assert') + +assert._COPYRIGHT = "Copyright (c) 2018 Olivine Labs, LLC." +assert._DESCRIPTION = + "Extends Lua's built-in assertions to provide additional tests and the ability to create your own." +assert._VERSION = "Luassert 1.8.0" + +-- load basic asserts +require('luassert.assertions') +require('luassert.modifiers') +require('luassert.array') +require('luassert.matchers') +require('luassert.formatters') + +-- load default language +require('luassert.languages.en') + +return assert diff --git a/framework/lualib/thirdparty/luassert/languages/en.lua b/framework/lualib/thirdparty/luassert/languages/en.lua new file mode 100755 index 0000000..d58d5c4 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/languages/en.lua @@ -0,0 +1,52 @@ +local s = require('luassert.say') + +s:set_namespace('en') + +s:set("assertion.same.positive", "Expected objects to be the same.\nPassed in:\n%s\nExpected:\n%s") +s:set("assertion.same.negative", "Expected objects to not be the same.\nPassed in:\n%s\nDid not expect:\n%s") + +s:set("assertion.equals.positive", "Expected objects to be equal.\nPassed in:\n%s\nExpected:\n%s") +s:set("assertion.equals.negative", "Expected objects to not be equal.\nPassed in:\n%s\nDid not expect:\n%s") + +s:set("assertion.near.positive", "Expected values to be near.\nPassed in:\n%s\nExpected:\n%s +/- %s") +s:set("assertion.near.negative", "Expected values to not be near.\nPassed in:\n%s\nDid not expect:\n%s +/- %s") + +s:set("assertion.matches.positive", "Expected strings to match.\nPassed in:\n%s\nExpected:\n%s") +s:set("assertion.matches.negative", "Expected strings not to match.\nPassed in:\n%s\nDid not expect:\n%s") + +s:set("assertion.unique.positive", "Expected object to be unique:\n%s") +s:set("assertion.unique.negative", "Expected object to not be unique:\n%s") + +s:set("assertion.error.positive", "Expected a different error.\nCaught:\n%s\nExpected:\n%s") +s:set("assertion.error.negative", "Expected no error, but caught:\n%s") + +s:set("assertion.truthy.positive", "Expected to be truthy, but value was:\n%s") +s:set("assertion.truthy.negative", "Expected to not be truthy, but value was:\n%s") + +s:set("assertion.falsy.positive", "Expected to be falsy, but value was:\n%s") +s:set("assertion.falsy.negative", "Expected to not be falsy, but value was:\n%s") + +s:set("assertion.called.positive", "Expected to be called %s time(s), but was called %s time(s)") +s:set("assertion.called.negative", "Expected not to be called exactly %s time(s), but it was.") + +s:set("assertion.called_at_least.positive", "Expected to be called at least %s time(s), but was called %s time(s)") +s:set("assertion.called_at_most.positive", "Expected to be called at most %s time(s), but was called %s time(s)") +s:set("assertion.called_more_than.positive", "Expected to be called more than %s time(s), but was called %s time(s)") +s:set("assertion.called_less_than.positive", "Expected to be called less than %s time(s), but was called %s time(s)") + +s:set("assertion.called_with.positive", + "Function was never called with matching arguments.\nCalled with (last call if any):\n%s\nExpected:\n%s") +s:set("assertion.called_with.negative", + "Function was called with matching arguments at least once.\nCalled with (last matching call):\n%s\nDid not expect:\n%s") + +s:set("assertion.returned_with.positive", + "Function never returned matching arguments.\nReturned (last call if any):\n%s\nExpected:\n%s") +s:set("assertion.returned_with.negative", + "Function returned matching arguments at least once.\nReturned (last matching call):\n%s\nDid not expect:\n%s") + +s:set("assertion.returned_arguments.positive", "Expected to be called with %s argument(s), but was called with %s") +s:set("assertion.returned_arguments.negative", "Expected not to be called with %s argument(s), but was called with %s") + +-- errors +s:set("assertion.internal.argtolittle", "the '%s' function requires a minimum of %s arguments, got: %s") +s:set("assertion.internal.badargtype", "bad argument #%s to '%s' (%s expected, got %s)") diff --git a/framework/lualib/thirdparty/luassert/languages/zh.lua b/framework/lualib/thirdparty/luassert/languages/zh.lua new file mode 100755 index 0000000..f53ca44 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/languages/zh.lua @@ -0,0 +1,31 @@ +local s = require('luassert.say') + +s:set_namespace('zh') + +s:set("assertion.same.positive", "希望对象应该相同.\n实际值:\n%s\n希望值:\n%s") +s:set("assertion.same.negative", "希望对象应该不相同.\n实际值:\n%s\n不希望与:\n%s\n相同") + +s:set("assertion.equals.positive", "希望对象应该相等.\n实际值:\n%s\n希望值:\n%s") +s:set("assertion.equals.negative", "希望对象应该不相等.\n实际值:\n%s\n不希望等于:\n%s") + +s:set("assertion.unique.positive", "希望对象是唯一的:\n%s") +s:set("assertion.unique.negative", "希望对象不是唯一的:\n%s") + +s:set("assertion.error.positive", "希望有错误被抛出.") +s:set("assertion.error.negative", "希望没有错误被抛出.\n%s") + +s:set("assertion.truthy.positive", "希望结果为真,但是实际为:\n%s") +s:set("assertion.truthy.negative", "希望结果不为真,但是实际为:\n%s") + +s:set("assertion.falsy.positive", "希望结果为假,但是实际为:\n%s") +s:set("assertion.falsy.negative", "希望结果不为假,但是实际为:\n%s") + +s:set("assertion.called.positive", "希望被调用%s次, 但实际被调用了%s次") +s:set("assertion.called.negative", "不希望正好被调用%s次, 但是正好被调用了那么多次.") + +s:set("assertion.called_with.positive", "希望没有参数的调用函数") +s:set("assertion.called_with.negative", "希望有参数的调用函数") + +-- errors +s:set("assertion.internal.argtolittle", "函数'%s'需要最少%s个参数, 实际有%s个参数\n") +s:set("assertion.internal.badargtype", "bad argument #%s: 函数'%s'需要一个%s作为参数, 实际为: %s\n") diff --git a/framework/lualib/thirdparty/luassert/match.lua b/framework/lualib/thirdparty/luassert/match.lua new file mode 100755 index 0000000..671c82f --- /dev/null +++ b/framework/lualib/thirdparty/luassert/match.lua @@ -0,0 +1,79 @@ +local namespace = require 'luassert.namespaces' +local util = require 'luassert.util' + +local matcher_mt = { + __call = function(self, value) + return self.callback(value) == self.mod + end, +} + +local state_mt = { + __call = function(self, ...) + local keys = util.extract_keys("matcher", self.tokens) + self.tokens = {} + + local matcher + + for _, key in ipairs(keys) do + matcher = namespace.matcher[key] or matcher + end + + if matcher then + for _, key in ipairs(keys) do + if namespace.modifier[key] then + namespace.modifier[key].callback(self) + end + end + + local arguments = util.make_arglist(...) + local matches = matcher.callback(self, arguments, util.errorlevel()) + return setmetatable({ + name = matcher.name, + mod = self.mod, + callback = matches, + arguments = arguments, + }, matcher_mt) + else + local arguments = util.make_arglist(...) + + for _, key in ipairs(keys) do + if namespace.modifier[key] then + namespace.modifier[key].callback(self, arguments, util.errorlevel()) + end + end + end + + return self + end, + + __index = function(self, key) + for token in key:lower():gmatch('[^_]+') do + table.insert(self.tokens, token) + end + + return self + end +} + +local match = { + _ = setmetatable({mod=true, callback=function() return true end}, matcher_mt), + + state = function() return setmetatable({mod=true, tokens={}}, state_mt) end, + + is_matcher = function(object) + return type(object) == "table" and getmetatable(object) == matcher_mt + end, + + is_ref_matcher = function(object) + local ismatcher = (type(object) == "table" and getmetatable(object) == matcher_mt) + return ismatcher and object.name == "ref" + end, +} + +local mt = { + __index = function(self, key) + return rawget(self, key) or self.state()[key] + end, +} + +return setmetatable(match, mt) diff --git a/framework/lualib/thirdparty/luassert/matchers/composite.lua b/framework/lualib/thirdparty/luassert/matchers/composite.lua new file mode 100755 index 0000000..58856e4 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/matchers/composite.lua @@ -0,0 +1,64 @@ +local assert = require('luassert.assert') +local match = require('luassert.match') +local s = require('luassert.say') + +local function none(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 0, s("assertion.internal.argtolittle", {"none", 1, tostring(argcnt)}), level) + for i = 1, argcnt do + assert(match.is_matcher(arguments[i]), + s("assertion.internal.badargtype", {1, "none", "matcher", type(arguments[i])}), level) + end + + return function(value) + for _, matcher in ipairs(arguments) do + if matcher(value) then + return false + end + end + return true + end +end + +local function any(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 0, s("assertion.internal.argtolittle", {"any", 1, tostring(argcnt)}), level) + for i = 1, argcnt do + assert(match.is_matcher(arguments[i]), + s("assertion.internal.badargtype", {1, "any", "matcher", type(arguments[i])}), level) + end + + return function(value) + for _, matcher in ipairs(arguments) do + if matcher(value) then + return true + end + end + return false + end +end + +local function all(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 0, s("assertion.internal.argtolittle", {"all", 1, tostring(argcnt)}), level) + for i = 1, argcnt do + assert(match.is_matcher(arguments[i]), + s("assertion.internal.badargtype", {1, "all", "matcher", type(arguments[i])}), level) + end + + return function(value) + for _, matcher in ipairs(arguments) do + if not matcher(value) then + return false + end + end + return true + end +end + +assert:register("matcher", "none_of", none) +assert:register("matcher", "any_of", any) +assert:register("matcher", "all_of", all) diff --git a/framework/lualib/thirdparty/luassert/matchers/core.lua b/framework/lualib/thirdparty/luassert/matchers/core.lua new file mode 100755 index 0000000..66a3fb0 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/matchers/core.lua @@ -0,0 +1,173 @@ +-- module will return the list of matchers, and registers matchers with the main assert engine + +-- matchers take 1 parameters; +-- 1) state +-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils +-- 3) level The level of the error position relative to the called function +-- returns; function (or callable object); a function that, given an argument, returns a boolean + +local assert = require('luassert.assert') +local astate = require('luassert.state') +local util = require('luassert.util') +local s = require('luassert.say') + +local function format(val) + return astate.format_argument(val) or tostring(val) +end + +local function unique(state, arguments, level) + local deep = arguments[1] + return function(value) + local list = value + for k,v in pairs(list) do + for k2, v2 in pairs(list) do + if k ~= k2 then + if deep and util.deepcompare(v, v2, true) then + return false + else + if v == v2 then + return false + end + end + end + end + end + return true + end +end + +local function near(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 1, s("assertion.internal.argtolittle", { "near", 2, tostring(argcnt) }), level) + local expected = tonumber(arguments[1]) + local tolerance = tonumber(arguments[2]) + local numbertype = "number or object convertible to a number" + assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level) + assert(tolerance, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level) + + return function(value) + local actual = tonumber(value) + if not actual then return false end + return (actual >= expected - tolerance and actual <= expected + tolerance) + end +end + +local function matches(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 0, s("assertion.internal.argtolittle", { "matches", 1, tostring(argcnt) }), level) + local pattern = arguments[1] + local init = arguments[2] + local plain = arguments[3] + assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level) + assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { 2, "matches", "number", type(arguments[2]) }), level) + + return function(value) + local actualtype = type(value) + local actual = nil + if actualtype == "string" or actualtype == "number" or + actualtype == "table" and (getmetatable(value) or {}).__tostring then + actual = tostring(value) + end + if not actual then return false end + return (actual:find(pattern, init, plain) ~= nil) + end +end + +local function equals(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 0, s("assertion.internal.argtolittle", { "equals", 1, tostring(argcnt) }), level) + return function(value) + return value == arguments[1] + end +end + +local function same(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + assert(argcnt > 0, s("assertion.internal.argtolittle", { "same", 1, tostring(argcnt) }), level) + return function(value) + if type(value) == 'table' and type(arguments[1]) == 'table' then + local result = util.deepcompare(value, arguments[1], true) + return result + end + return value == arguments[1] + end +end + +local function ref(state, arguments, level) + local level = (level or 1) + 1 + local argcnt = arguments.n + local argtype = type(arguments[1]) + local isobject = (argtype == "table" or argtype == "function" or argtype == "thread" or argtype == "userdata") + assert(argcnt > 0, s("assertion.internal.argtolittle", { "ref", 1, tostring(argcnt) }), level) + assert(isobject, s("assertion.internal.badargtype", { 1, "ref", "object", argtype }), level) + return function(value) + return value == arguments[1] + end +end + +local function is_true(state, arguments, level) + return function(value) + return value == true + end +end + +local function is_false(state, arguments, level) + return function(value) + return value == false + end +end + +local function truthy(state, arguments, level) + return function(value) + return value ~= false and value ~= nil + end +end + +local function falsy(state, arguments, level) + local is_truthy = truthy(state, arguments, level) + return function(value) + return not is_truthy(value) + end +end + +local function is_type(state, arguments, level, etype) + return function(value) + return type(value) == etype + end +end + +local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end +local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end +local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end +local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end +local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end +local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end +local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end +local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end + +assert:register("matcher", "true", is_true) +assert:register("matcher", "false", is_false) + +assert:register("matcher", "nil", is_nil) +assert:register("matcher", "boolean", is_boolean) +assert:register("matcher", "number", is_number) +assert:register("matcher", "string", is_string) +assert:register("matcher", "table", is_table) +assert:register("matcher", "function", is_function) +assert:register("matcher", "userdata", is_userdata) +assert:register("matcher", "thread", is_thread) + +assert:register("matcher", "ref", ref) +assert:register("matcher", "same", same) +assert:register("matcher", "matches", matches) +assert:register("matcher", "match", matches) +assert:register("matcher", "near", near) +assert:register("matcher", "equals", equals) +assert:register("matcher", "equal", equals) +assert:register("matcher", "unique", unique) +assert:register("matcher", "truthy", truthy) +assert:register("matcher", "falsy", falsy) diff --git a/framework/lualib/thirdparty/luassert/matchers/init.lua b/framework/lualib/thirdparty/luassert/matchers/init.lua new file mode 100755 index 0000000..c0ad62b --- /dev/null +++ b/framework/lualib/thirdparty/luassert/matchers/init.lua @@ -0,0 +1,3 @@ +-- load basic machers +require('luassert.matchers.core') +require('luassert.matchers.composite') diff --git a/framework/lualib/thirdparty/luassert/mock.lua b/framework/lualib/thirdparty/luassert/mock.lua new file mode 100755 index 0000000..273c201 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/mock.lua @@ -0,0 +1,65 @@ +-- module will return a mock module table, and will not register any assertions +local spy = require 'luassert.spy' +local stub = require 'luassert.stub' + +local function mock_apply(object, action) + if type(object) ~= "table" then + return + end + if spy.is_spy(object) then + return object[action](object) + end + for k, v in pairs(object) do + mock_apply(v, action) + end + return object +end + +local mock +mock = { + new = function(object, dostub, func, self, key) + local visited = {} + local function do_mock(object, self, key) + local mock_handlers = { + ["table"] = function() + if spy.is_spy(object) or visited[object] then + return + end + visited[object] = true + for k, v in pairs(object) do + object[k] = do_mock(v, object, k) + end + return object + end, + ["function"] = function() + if dostub then + return stub(self, key, func) + elseif self == nil then + return spy.new(object) + else + return spy.on(self, key) + end + end, + } + local handler = mock_handlers[type(object)] + return handler and handler() or object + end + return do_mock(object, self, key) + end, + + clear = function(object) + return mock_apply(object, "clear") + end, + + revert = function(object) + return mock_apply(object, "revert") + end, +} + +return setmetatable(mock, { + __call = function(self, ...) + -- mock originally was a function only. Now that it is a module table + -- the __call method is required for backward compatibility + return mock.new(...) + end, +}) diff --git a/framework/lualib/thirdparty/luassert/modifiers.lua b/framework/lualib/thirdparty/luassert/modifiers.lua new file mode 100755 index 0000000..58ee5dc --- /dev/null +++ b/framework/lualib/thirdparty/luassert/modifiers.lua @@ -0,0 +1,19 @@ +-- module will not return anything, only register assertions/modifiers with the main assert engine +local assert = require('luassert.assert') + +local function is(state) + return state +end + +local function is_not(state) + state.mod = not state.mod + return state +end + +assert:register("modifier", "is", is) +assert:register("modifier", "are", is) +assert:register("modifier", "was", is) +assert:register("modifier", "has", is) +assert:register("modifier", "does", is) +assert:register("modifier", "not", is_not) +assert:register("modifier", "no", is_not) diff --git a/framework/lualib/thirdparty/luassert/namespaces.lua b/framework/lualib/thirdparty/luassert/namespaces.lua new file mode 100755 index 0000000..0790fce --- /dev/null +++ b/framework/lualib/thirdparty/luassert/namespaces.lua @@ -0,0 +1,2 @@ +-- stores the list of namespaces +return {} diff --git a/framework/lualib/thirdparty/luassert/say.lua b/framework/lualib/thirdparty/luassert/say.lua new file mode 100644 index 0000000..1fd74d8 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/say.lua @@ -0,0 +1,64 @@ +local unpack = table.unpack or unpack + +local registry = {} +local current_namespace +local fallback_namespace + +local s = { + + _COPYRIGHT = "Copyright (c) 2012 Olivine Labs, LLC.", + _DESCRIPTION = "A simple string key/value store for i18n or any other case where you want namespaced strings.", + _VERSION = "Say 1.3", + + set_namespace = function(self, namespace) + current_namespace = namespace + if not registry[current_namespace] then + registry[current_namespace] = {} + end + end, + + set_fallback = function(self, namespace) + fallback_namespace = namespace + if not registry[fallback_namespace] then + registry[fallback_namespace] = {} + end + end, + + set = function(self, key, value) + registry[current_namespace][key] = value + end, +} + +local __meta = { + __call = function(self, key, vars) + if vars ~= nil and type(vars) ~= "table" then + error(("expected parameter table to be a table, got '%s'"):format(type(vars)), 2) + end + vars = vars or {} + + local str = registry[current_namespace][key] or registry[fallback_namespace][key] + + if str == nil then + return nil + end + str = tostring(str) + local strings = {} + + for i = 1, vars.n or #vars do + table.insert(strings, tostring(vars[i])) + end + + return #strings > 0 and str:format(unpack(strings)) or str + end, + + __index = function(self, key) + return registry[key] + end, +} + +s:set_fallback('en') +s:set_namespace('en') + +s._registry = registry + +return setmetatable(s, __meta) diff --git a/framework/lualib/thirdparty/luassert/spy.lua b/framework/lualib/thirdparty/luassert/spy.lua new file mode 100755 index 0000000..51f60b8 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/spy.lua @@ -0,0 +1,215 @@ +-- module will return spy table, and register its assertions with the main assert engine +local assert = require('luassert.assert') +local util = require('luassert.util') + +-- Spy metatable +local spy_mt = { + __call = function(self, ...) + local arguments = util.make_arglist(...) + table.insert(self.calls, util.copyargs(arguments)) + local function get_returns(...) + local returnvals = util.make_arglist(...) + table.insert(self.returnvals, util.copyargs(returnvals)) + return ... + end + return get_returns(self.callback(...)) + end, +} + +local spy -- must make local before defining table, because table contents refers to the table (recursion) +spy = { + new = function(callback) + callback = callback or function() + end + if not util.callable(callback) then + error("Cannot spy on type '" .. type(callback) .. "', only on functions or callable elements", + util.errorlevel()) + end + local s = setmetatable({ + calls = {}, + returnvals = {}, + callback = callback, + + target_table = nil, -- these will be set when using 'spy.on' + target_key = nil, + + revert = function(self) + if not self.reverted then + if self.target_table and self.target_key then + self.target_table[self.target_key] = self.callback + end + self.reverted = true + end + return self.callback + end, + + clear = function(self) + self.calls = {} + self.returnvals = {} + return self + end, + + called = function(self, times, compare) + if times or compare then + local compare = compare or function(count, expected) + return count == expected + end + return compare(#self.calls, times), #self.calls + end + + return (#self.calls > 0), #self.calls + end, + + called_with = function(self, args) + local last_arglist = nil + if #self.calls > 0 then + last_arglist = self.calls[#self.calls].vals + end + local matching_arglists = util.matchargs(self.calls, args) + if matching_arglists ~= nil then + return true, matching_arglists.vals + end + return false, last_arglist + end, + + returned_with = function(self, args) + local last_returnvallist = nil + if #self.returnvals > 0 then + last_returnvallist = self.returnvals[#self.returnvals].vals + end + local matching_returnvallists = util.matchargs(self.returnvals, args) + if matching_returnvallists ~= nil then + return true, matching_returnvallists.vals + end + return false, last_returnvallist + end, + }, spy_mt) + assert:add_spy(s) -- register with the current state + return s + end, + + is_spy = function(object) + return type(object) == "table" and getmetatable(object) == spy_mt + end, + + on = function(target_table, target_key) + local s = spy.new(target_table[target_key]) + target_table[target_key] = s + -- store original data + s.target_table = target_table + s.target_key = target_key + + return s + end, +} + +local function set_spy(state, arguments, level) + state.payload = arguments[1] + if arguments[2] ~= nil then + state.failure_message = arguments[2] + end +end + +local function returned_with(state, arguments, level) + local level = (level or 1) + 1 + local payload = rawget(state, "payload") + if payload and payload.returned_with then + local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments) + local expected_returnvallist = util.shallowcopy(arguments) + util.cleararglist(arguments) + util.tinsert(arguments, 1, matching_or_last_returnvallist) + util.tinsert(arguments, 2, expected_returnvallist) + return assertion_holds + else + error("'returned_with' must be chained after 'spy(aspy)'", level) + end +end + +local function called_with(state, arguments, level) + local level = (level or 1) + 1 + local payload = rawget(state, "payload") + if payload and payload.called_with then + local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments) + local expected_arglist = util.shallowcopy(arguments) + util.cleararglist(arguments) + util.tinsert(arguments, 1, matching_or_last_arglist) + util.tinsert(arguments, 2, expected_arglist) + return assertion_holds + else + error("'called_with' must be chained after 'spy(aspy)'", level) + end +end + +local function called(state, arguments, level, compare) + local level = (level or 1) + 1 + local num_times = arguments[1] + if not num_times and not state.mod then + state.mod = true + num_times = 0 + end + local payload = rawget(state, "payload") + if payload and type(payload) == "table" and payload.called then + local result, count = state.payload:called(num_times, compare) + arguments[1] = tostring(num_times or ">0") + util.tinsert(arguments, 2, tostring(count)) + arguments.nofmt = arguments.nofmt or {} + arguments.nofmt[1] = true + arguments.nofmt[2] = true + return result + elseif payload and type(payload) == "function" then + error( + "When calling 'spy(aspy)', 'aspy' must not be the original function, but the spy function replacing the original", + level) + else + error("'called' must be chained after 'spy(aspy)'", level) + end +end + +local function called_at_least(state, arguments, level) + local level = (level or 1) + 1 + return called(state, arguments, level, function(count, expected) + return count >= expected + end) +end + +local function called_at_most(state, arguments, level) + local level = (level or 1) + 1 + return called(state, arguments, level, function(count, expected) + return count <= expected + end) +end + +local function called_more_than(state, arguments, level) + local level = (level or 1) + 1 + return called(state, arguments, level, function(count, expected) + return count > expected + end) +end + +local function called_less_than(state, arguments, level) + local level = (level or 1) + 1 + return called(state, arguments, level, function(count, expected) + return count < expected + end) +end + +assert:register("modifier", "spy", set_spy) +assert:register("assertion", "returned_with", returned_with, "assertion.returned_with.positive", + "assertion.returned_with.negative") +assert:register("assertion", "called_with", called_with, "assertion.called_with.positive", + "assertion.called_with.negative") +assert:register("assertion", "called", called, "assertion.called.positive", "assertion.called.negative") +assert:register("assertion", "called_at_least", called_at_least, "assertion.called_at_least.positive", + "assertion.called_less_than.positive") +assert:register("assertion", "called_at_most", called_at_most, "assertion.called_at_most.positive", + "assertion.called_more_than.positive") +assert:register("assertion", "called_more_than", called_more_than, "assertion.called_more_than.positive", + "assertion.called_at_most.positive") +assert:register("assertion", "called_less_than", called_less_than, "assertion.called_less_than.positive", + "assertion.called_at_least.positive") + +return setmetatable(spy, { + __call = function(self, ...) + return spy.new(...) + end, +}) diff --git a/framework/lualib/thirdparty/luassert/state.lua b/framework/lualib/thirdparty/luassert/state.lua new file mode 100755 index 0000000..4210727 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/state.lua @@ -0,0 +1,134 @@ +-- maintains a state of the assert engine in a linked-list fashion +-- records; formatters, parameters, spies and stubs +local state_mt = { + __call = function(self) + self:revert() + end, +} + +local spies_mt = { + __mode = "kv", +} + +local nilvalue = {} -- unique ID to refer to nil values for parameters + +-- will hold the current state +local current + +-- exported module table +local state = {} + +------------------------------------------------------ +-- Reverts to a (specific) snapshot. +-- @param self (optional) the snapshot to revert to. If not provided, it will revert to the last snapshot. +state.revert = function(self) + if not self then + -- no snapshot given, so move 1 up + self = current + if not self.previous then + -- top of list, no previous one, nothing to do + return + end + end + if getmetatable(self) ~= state_mt then + error("Value provided is not a valid snapshot", 2) + end + + if self.next then + self.next:revert() + end + -- revert formatters in 'last' + self.formatters = {} + -- revert parameters in 'last' + self.parameters = {} + -- revert spies/stubs in 'last' + for s, _ in pairs(self.spies) do + self.spies[s] = nil + s:revert() + end + setmetatable(self, nil) -- invalidate as a snapshot + current = self.previous + current.next = nil +end + +------------------------------------------------------ +-- Creates a new snapshot. +-- @return snapshot table +state.snapshot = function() + local new = setmetatable({ + formatters = {}, + parameters = {}, + spies = setmetatable({}, spies_mt), + previous = current, + revert = state.revert, + }, state_mt) + if current then + current.next = new + end + current = new + return current +end + +-- FORMATTERS +state.add_formatter = function(callback) + table.insert(current.formatters, 1, callback) +end + +state.remove_formatter = function(callback, s) + s = s or current + for i, v in ipairs(s.formatters) do + if v == callback then + table.remove(s.formatters, i) + break + end + end + -- wasn't found, so traverse up 1 state + if s.previous then + state.remove_formatter(callback, s.previous) + end +end + +state.format_argument = function(val, s, fmtargs) + s = s or current + for _, fmt in ipairs(s.formatters) do + local valfmt = fmt(val, fmtargs) + if valfmt ~= nil then + return valfmt + end + end + -- nothing found, check snapshot 1 up in list + if s.previous then + return state.format_argument(val, s.previous, fmtargs) + end + return nil -- end of list, couldn't format +end + +-- PARAMETERS +state.set_parameter = function(name, value) + if value == nil then + value = nilvalue + end + current.parameters[name] = value +end + +state.get_parameter = function(name, s) + s = s or current + local val = s.parameters[name] + if val == nil and s.previous then + -- not found, so check 1 up in list + return state.get_parameter(name, s.previous) + end + if val ~= nilvalue then + return val + end + return nil +end + +-- SPIES / STUBS +state.add_spy = function(spy) + current.spies[spy] = true +end + +state.snapshot() -- create initial state + +return state diff --git a/framework/lualib/thirdparty/luassert/stub.lua b/framework/lualib/thirdparty/luassert/stub.lua new file mode 100755 index 0000000..c2c9a5f --- /dev/null +++ b/framework/lualib/thirdparty/luassert/stub.lua @@ -0,0 +1,109 @@ +-- module will return a stub module table +local assert = require 'luassert.assert' +local spy = require 'luassert.spy' +local util = require 'luassert.util' +local unpack = util.unpack +local pack = util.pack + +local stub = {} + +function stub.new(object, key, ...) + if object == nil and key == nil then + -- called without arguments, create a 'blank' stub + object = {} + key = "" + end + local return_values = pack(...) + assert(type(object) == "table" and key ~= nil, + "stub.new(): Can only create stub on a table key, call with 2 params; table, key", util.errorlevel()) + assert(object[key] == nil or util.callable(object[key]), + "stub.new(): The element for which to create a stub must either be callable, or be nil", util.errorlevel()) + local old_elem = object[key] -- keep existing element (might be nil!) + + local fn = (return_values.n == 1 and util.callable(return_values[1]) and return_values[1]) + local defaultfunc = fn or function() + return unpack(return_values) + end + local oncalls = {} + local callbacks = {} + local stubfunc = function(...) + local args = util.make_arglist(...) + local match = util.matchoncalls(oncalls, args) + if match then + return callbacks[match](...) + end + return defaultfunc(...) + end + + object[key] = stubfunc -- set the stubfunction + local s = spy.on(object, key) -- create a spy on top of the stub function + local spy_revert = s.revert -- keep created revert function + + s.revert = function(self) -- wrap revert function to restore original element + if not self.reverted then + spy_revert(self) + object[key] = old_elem + self.reverted = true + end + return old_elem + end + + s.returns = function(...) + local return_args = pack(...) + defaultfunc = function() + return unpack(return_args) + end + return s + end + + s.invokes = function(func) + defaultfunc = function(...) + return func(...) + end + return s + end + + s.by_default = { + returns = s.returns, + invokes = s.invokes, + } + + s.on_call_with = function(...) + local match_args = util.make_arglist(...) + match_args = util.copyargs(match_args) + return { + returns = function(...) + local return_args = pack(...) + table.insert(oncalls, match_args) + callbacks[match_args] = function() + return unpack(return_args) + end + return s + end, + invokes = function(func) + table.insert(oncalls, match_args) + callbacks[match_args] = function(...) + return func(...) + end + return s + end, + } + end + + return s +end + +local function set_stub(state, arguments) + state.payload = arguments[1] + state.failure_message = arguments[2] +end + +assert:register("modifier", "stub", set_stub) + +return setmetatable(stub, { + __call = function(self, ...) + -- stub originally was a function only. Now that it is a module table + -- the __call method is required for backward compatibility + return stub.new(...) + end, +}) diff --git a/framework/lualib/thirdparty/luassert/util.lua b/framework/lualib/thirdparty/luassert/util.lua new file mode 100755 index 0000000..5658899 --- /dev/null +++ b/framework/lualib/thirdparty/luassert/util.lua @@ -0,0 +1,386 @@ +local util = {} +local arglist_mt = {} + +-- have pack/unpack both respect the 'n' field +local _unpack = table.unpack or unpack +local unpack = function(t, i, j) + return _unpack(t, i or 1, j or t.n or #t) +end +local pack = function(...) + return { + n = select("#", ...), + ..., + } +end +util.pack = pack +util.unpack = unpack + +function util.deepcompare(t1, t2, ignore_mt, cycles, thresh1, thresh2) + local ty1 = type(t1) + local ty2 = type(t2) + -- non-table types can be directly compared + if ty1 ~= 'table' or ty2 ~= 'table' then + return t1 == t2 + end + local mt1 = debug.getmetatable(t1) + local mt2 = debug.getmetatable(t2) + -- would equality be determined by metatable __eq? + if mt1 and mt1 == mt2 and mt1.__eq then + -- then use that unless asked not to + if not ignore_mt then + return t1 == t2 + end + else -- we can skip the deep comparison below if t1 and t2 share identity + if rawequal(t1, t2) then + return true + end + end + + -- handle recursive tables + cycles = cycles or {{}, {}} + thresh1, thresh2 = (thresh1 or 1), (thresh2 or 1) + cycles[1][t1] = (cycles[1][t1] or 0) + cycles[2][t2] = (cycles[2][t2] or 0) + if cycles[1][t1] == 1 or cycles[2][t2] == 1 then + thresh1 = cycles[1][t1] + 1 + thresh2 = cycles[2][t2] + 1 + end + if cycles[1][t1] > thresh1 and cycles[2][t2] > thresh2 then + return true + end + + cycles[1][t1] = cycles[1][t1] + 1 + cycles[2][t2] = cycles[2][t2] + 1 + + for k1, v1 in next, t1 do + local v2 = t2[k1] + if v2 == nil then + return false, {k1} + end + + local same, crumbs = util.deepcompare(v1, v2, nil, cycles, thresh1, thresh2) + if not same then + crumbs = crumbs or {} + table.insert(crumbs, k1) + return false, crumbs + end + end + for k2, _ in next, t2 do + -- only check whether each element has a t1 counterpart, actual comparison + -- has been done in first loop above + if t1[k2] == nil then + return false, {k2} + end + end + + cycles[1][t1] = cycles[1][t1] - 1 + cycles[2][t2] = cycles[2][t2] - 1 + + return true +end + +function util.shallowcopy(t) + if type(t) ~= "table" then + return t + end + local copy = {} + setmetatable(copy, getmetatable(t)) + for k, v in next, t do + copy[k] = v + end + return copy +end + +function util.deepcopy(t, deepmt, cache) + local spy = require 'luassert.spy' + if type(t) ~= "table" then + return t + end + local copy = {} + + -- handle recursive tables + local cache = cache or {} + if cache[t] then + return cache[t] + end + cache[t] = copy + + for k, v in next, t do + copy[k] = (spy.is_spy(v) and v or util.deepcopy(v, deepmt, cache)) + end + if deepmt then + debug.setmetatable(copy, util.deepcopy(debug.getmetatable(t, nil, cache))) + else + debug.setmetatable(copy, debug.getmetatable(t)) + end + return copy +end + +----------------------------------------------- +-- Copies arguments as a list of arguments +-- @param args the arguments of which to copy +-- @return the copy of the arguments +function util.copyargs(args) + local copy = {} + setmetatable(copy, getmetatable(args)) + local match = require 'luassert.match' + local spy = require 'luassert.spy' + for k, v in pairs(args) do + copy[k] = ((match.is_matcher(v) or spy.is_spy(v)) and v or util.deepcopy(v)) + end + return { + vals = copy, + refs = util.shallowcopy(args), + } +end + +----------------------------------------------- +-- Clear an arguments or return values list from a table +-- @param arglist the table to clear of arguments or return values and their count +-- @return No return values +function util.cleararglist(arglist) + for idx = arglist.n, 1, -1 do + util.tremove(arglist, idx) + end + arglist.n = nil +end + +----------------------------------------------- +-- Test specs against an arglist in deepcopy and refs flavours. +-- @param args deepcopy arglist +-- @param argsrefs refs arglist +-- @param specs arguments/return values to match against args/argsrefs +-- @return true if specs match args/argsrefs, false otherwise +local function matcharg(args, argrefs, specs) + local match = require 'luassert.match' + for idx, argval in pairs(args) do + local spec = specs[idx] + if match.is_matcher(spec) then + if match.is_ref_matcher(spec) then + argval = argrefs[idx] + end + if not spec(argval) then + return false + end + elseif (spec == nil or not util.deepcompare(argval, spec)) then + return false + end + end + + for idx, spec in pairs(specs) do + -- only check whether each element has an args counterpart, + -- actual comparison has been done in first loop above + local argval = args[idx] + if argval == nil then + -- no args counterpart, so try to compare using matcher + if match.is_matcher(spec) then + if not spec(argval) then + return false + end + else + return false + end + end + end + return true +end + +----------------------------------------------- +-- Find matching arguments/return values in a saved list of +-- arguments/returned values. +-- @param invocations_list list of arguments/returned values to search (list of lists) +-- @param specs arguments/return values to match against argslist +-- @return the last matching arguments/returned values if a match is found, otherwise nil +function util.matchargs(invocations_list, specs) + -- Search the arguments/returned values last to first to give the + -- most helpful answer possible. In the cases where you can place + -- your assertions between calls to check this gives you the best + -- information if no calls match. In the cases where you can't do + -- that there is no good way to predict what would work best. + assert(not util.is_arglist(invocations_list), "expected a list of arglist-object, got an arglist") + for ii = #invocations_list, 1, -1 do + local val = invocations_list[ii] + if matcharg(val.vals, val.refs, specs) then + return val + end + end + return nil +end + +----------------------------------------------- +-- Find matching oncall for an actual call. +-- @param oncalls list of oncalls to search +-- @param args actual call argslist to match against +-- @return the first matching oncall if a match is found, otherwise nil +function util.matchoncalls(oncalls, args) + for _, callspecs in ipairs(oncalls) do + -- This lookup is done immediately on *args* passing into the stub + -- so pass *args* as both *args* and *argsref* without copying + -- either. + if matcharg(args, args, callspecs.vals) then + return callspecs + end + end + return nil +end + +----------------------------------------------- +-- table.insert() replacement that respects nil values. +-- The function will use table field 'n' as indicator of the +-- table length, if not set, it will be added. +-- @param t table into which to insert +-- @param pos (optional) position in table where to insert. NOTE: not optional if you want to insert a nil-value! +-- @param val value to insert +-- @return No return values +function util.tinsert(...) + -- check optional POS value + local args = {...} + local c = select('#', ...) + local t = args[1] + local pos = args[2] + local val = args[3] + if c < 3 then + val = pos + pos = nil + end + -- set length indicator n if not present (+1) + t.n = (t.n or #t) + 1 + if not pos then + pos = t.n + elseif pos > t.n then + -- out of our range + t[pos] = val + t.n = pos + end + -- shift everything up 1 pos + for i = t.n, pos + 1, -1 do + t[i] = t[i - 1] + end + -- add element to be inserted + t[pos] = val +end +----------------------------------------------- +-- table.remove() replacement that respects nil values. +-- The function will use table field 'n' as indicator of the +-- table length, if not set, it will be added. +-- @param t table from which to remove +-- @param pos (optional) position in table to remove +-- @return No return values +function util.tremove(t, pos) + -- set length indicator n if not present (+1) + t.n = t.n or #t + if not pos then + pos = t.n + elseif pos > t.n then + local removed = t[pos] + -- out of our range + t[pos] = nil + return removed + end + local removed = t[pos] + -- shift everything up 1 pos + for i = pos, t.n do + t[i] = t[i + 1] + end + -- set size, clean last + t[t.n] = nil + t.n = t.n - 1 + return removed +end + +----------------------------------------------- +-- Checks an element to be callable. +-- The type must either be a function or have a metatable +-- containing an '__call' function. +-- @param object element to inspect on being callable or not +-- @return boolean, true if the object is callable +function util.callable(object) + return type(object) == "function" or type((debug.getmetatable(object) or {}).__call) == "function" +end +----------------------------------------------- +-- Checks an element has tostring. +-- The type must either be a string or have a metatable +-- containing an '__tostring' function. +-- @param object element to inspect on having tostring or not +-- @return boolean, true if the object has tostring +function util.hastostring(object) + return type(object) == "string" or type((debug.getmetatable(object) or {}).__tostring) == "function" +end + +----------------------------------------------- +-- Find the first level, not defined in the same file as the caller's +-- code file to properly report an error. +-- @param level the level to use as the caller's source file +-- @return number, the level of which to report an error +function util.errorlevel(level) + local level = (level or 1) + 1 -- add one to get level of the caller + local info = debug.getinfo(level) + local source = (info or {}).source + local file = source + while file and (file == source or source == "=(tail call)") do + level = level + 1 + info = debug.getinfo(level) + source = (info or {}).source + end + if level > 1 then + level = level - 1 + end -- deduct call to errorlevel() itself + return level +end + +----------------------------------------------- +-- Extract modifier and namespace keys from list of tokens. +-- @param nspace the namespace from which to match tokens +-- @param tokens list of tokens to search for keys +-- @return table, list of keys that were extracted +function util.extract_keys(nspace, tokens) + local namespace = require 'luassert.namespaces' + + -- find valid keys by coalescing tokens as needed, starting from the end + local keys = {} + local key = nil + local i = #tokens + while i > 0 do + local token = tokens[i] + key = key and (token .. '_' .. key) or token + + -- find longest matching key in the given namespace + local longkey = i > 1 and (tokens[i - 1] .. '_' .. key) or nil + while i > 1 and longkey and namespace[nspace][longkey] do + key = longkey + i = i - 1 + token = tokens[i] + longkey = (token .. '_' .. key) + end + + if namespace.modifier[key] or namespace[nspace][key] then + table.insert(keys, 1, key) + key = nil + end + i = i - 1 + end + + -- if there's anything left we didn't recognize it + if key then + error("luassert: unknown modifier/" .. nspace .. ": '" .. key .. "'", util.errorlevel(2)) + end + + return keys +end + +----------------------------------------------- +-- store argument list for return values of a function in a table. +-- The table will get a metatable to identify it as an arglist +function util.make_arglist(...) + local arglist = {...} + arglist.n = select('#', ...) -- add values count for trailing nils + return setmetatable(arglist, arglist_mt) +end + +----------------------------------------------- +-- check a table to be an arglist type. +function util.is_arglist(object) + return getmetatable(object) == arglist_mt +end + +return util