# Builds the genwasym runtime as a static archive and a shared library.
CXX = g++
CXXFLAGS = -std=c++17 -Wall -Iinclude -I../third-party/immer -fPIC
BUILD_DIR = build

SRC = lib/genwasym.cpp lib/wasm_state_continue.cpp
# One object per source file, derived from SRC so the two lists cannot drift.
OBJ = $(patsubst lib/%.cpp,$(BUILD_DIR)/%.o,$(SRC))
# Compiler-generated header dependencies (-MMD), so header edits trigger rebuilds.
DEP = $(OBJ:.o=.d)

STATIC_LIB = $(BUILD_DIR)/libgenwasym.a
SHARED_LIB = $(BUILD_DIR)/libgenwasym.so

# -install_name is an option of Apple's linker; ELF linkers record the library
# name with -soname instead.
ifeq ($(shell uname -s),Darwin)
SHARED_NAME_FLAG = -Wl,-install_name,@rpath/libgenwasym.so
else
SHARED_NAME_FLAG = -Wl,-soname,libgenwasym.so
endif

.PHONY: all clean

all: $(BUILD_DIR) $(STATIC_LIB) $(SHARED_LIB)

$(BUILD_DIR):
	mkdir -p $(BUILD_DIR)

# The build directory is an order-only prerequisite: it must exist, but its
# timestamp must not force recompilation.
$(BUILD_DIR)/%.o: lib/%.cpp | $(BUILD_DIR)
	$(CXX) $(CXXFLAGS) -MMD -MP -c $< -o $@

$(STATIC_LIB): $(OBJ)
	ar rcs $@ $^

$(SHARED_LIB): $(OBJ)
	$(CXX) -shared $(SHARED_NAME_FLAG) -o $@ $^

clean:
	rm -rf $(BUILD_DIR)

-include $(DEP)
a/genwasym_runtime/include/wasm/concolic_driver.hpp b/genwasym_runtime/include/wasm/concolic_driver.hpp new file mode 100644 index 00000000..fb8ffb4e --- /dev/null +++ b/genwasym_runtime/include/wasm/concolic_driver.hpp @@ -0,0 +1,277 @@ +#ifndef CONCOLIC_DRIVER_HPP +#define CONCOLIC_DRIVER_HPP + +#include "concrete_rt.hpp" +#include "config.hpp" +#include "output_report.hpp" +#include "profile.hpp" +#include "smt_solver.hpp" +#include "sym_rt.hpp" +#include "utils.hpp" +#include "z3++.h" +#include +#include +#include +#include +#include +#include +#include +#include + +class ConcolicDriver { + friend class ManagedConcolicCleanup; + +public: + ConcolicDriver(std::function entrypoint, + std::optional tree_file, int branchCount) + : entrypoint(entrypoint), tree_file(tree_file) { + ExploreTree.true_branch_cov_map.assign(branchCount, false); + ExploreTree.false_branch_cov_map.assign(branchCount, false); + } + void run(); + +private: + void main_exploration_loop(); + std::optional get_new_input(); + std::vector> collect_all_unexplored_path_conds(); + std::function entrypoint; + std::optional tree_file; + std::vector work_list; + std::set visited; +}; + +class ManagedConcolicCleanup { + const ConcolicDriver &driver; + +public: + ManagedConcolicCleanup(const ConcolicDriver &driver) : driver(driver) {} + ~ManagedConcolicCleanup() { + // put any cleanup code that needs to be done after each execution here + + // Dump the explore tree if needed + if (driver.tree_file.has_value()) + ExploreTree.dump_graphviz(driver.tree_file.value()); + + // Profile.print_summary(); + } +}; + +static std::monostate reset_stacks(); + +// A PathFrontier represents the frontier of an unexplored path. From this +// frontier, we can explore the path by executing the program from the beginning +// with the model stored in QueryResult. 
+struct PathFrontier { + QueryResult query_result; + NodeBox *node; +}; + +class PathPicker { +public: + PathPicker(std::vector &unexplored_paths, + std::set &visited) + : unexplored_paths(unexplored_paths), visited(visited) {} + + virtual std::optional pick_path() = 0; + +protected: + std::vector &unexplored_paths; + std::set &visited; +}; + +class DefaultPathPicker : public PathPicker { +public: + DefaultPathPicker(std::vector &unexplored_paths, + std::set &visited) + : PathPicker(unexplored_paths, visited) {} + + std::optional pick_path() override { + NodeBox *node = unexplored_paths.back(); + unexplored_paths.pop_back(); + + if (visited.find(node) != visited.end()) { + return std::nullopt; + } else { + visited.insert(node); + } + + if (!node->isUnexplored()) { + // if it's not unexplored anymore, skip it + return std::nullopt; + } + + std::optional result; + { + ManagedTimer timer(TimeProfileKind::SOLVER_TOTAL); + auto cond = node->collect_path_conds(); + result = solver.solve_path_conds(cond, true); + } + if (!result.has_value()) { + GENSYM_INFO("Found an unreachable path, marking it as unreachable..."); + node->fillUnreachableNode(); + return std::nullopt; + } + return PathFrontier{result.value(), node}; + } +}; + +class RandomPathPicker : public PathPicker { +public: + RandomPathPicker(std::vector &unexplored_paths, + std::set &visited) + : PathPicker(unexplored_paths, visited) {} + std::optional pick_path() override { + ManagedTimer timer(TimeProfileKind::SOLVER_TOTAL); + + if (unexplored_paths.empty()) { + return std::nullopt; + } + std::vector> all_path_conds; + std::vector candidate_nodes; + + for (auto node : unexplored_paths) { + ManagedTimer timer(TimeProfileKind::COLLECT_PATH_CONDITIONS); + if (visited.find(node) != visited.end()) { + continue; + } + if (!node->isUnexplored()) { + // I suppose thse should not happen + // assert(false); + continue; + } + all_path_conds.push_back(node->collect_path_conds()); + candidate_nodes.push_back(node); + } + + 
auto result = solver.find_reachable_path_with_witness(all_path_conds, + candidate_nodes); + if (!result.has_value()) { + for (auto node : candidate_nodes) { + GENSYM_INFO("Found an unreachable path, marking it as unreachable..."); + node->fillUnreachableNode(); + } + unexplored_paths.clear(); + return std::nullopt; + } + return PathFrontier{.query_result = *result, .node = result->witness}; + } +}; + +inline void ConcolicDriver::main_exploration_loop() { + + // Register a collector to ExploreTree to add new nodes to work_list + ExploreTree.register_new_node_collector([&](NodeBox *new_node) { + if (std::find(work_list.begin(), work_list.end(), new_node) == + work_list.end()) + work_list.push_back(new_node); + }); + + assert(ExploreTree.get_root()->isUnexplored() && + "Before main loop, root should be unexplored!"); + work_list.push_back(ExploreTree.get_root()); + + PathPicker &&picker = DefaultPathPicker(work_list, visited); + + while (!work_list.empty()) { + if (INTERACTIVE_MODE) { + std::cout << "Press Enter to continue to the next path..." 
<< std::endl; + std::cin.get(); + } + ManagedConcolicCleanup cleanup{*this}; + ManagedTimer timer(TimeProfileKind::MAIN_LOOP); + // Pick a frontier of an unexplored path from the work list + auto frontier = picker.pick_path(); + if (!frontier.has_value()) { + continue; + } + + auto &node = frontier.value().node; + + const NumMap &new_env = *frontier.value().query_result.map_box; + z3::model &model = frontier.value().query_result.model; + + // update global symbolic environment from SMT solved model + SymEnv.update(new_env); + try { + GENSYM_INFO("Now execute the program with symbolic environment: "); + GENSYM_INFO(SymEnv.to_string()); + auto snapshot = dynamic_cast(node->node.get()); + if (REUSE_SNAPSHOT && snapshot && snapshot->worth_to_reuse()) { + assert(REUSE_SNAPSHOT); + Profile.incr_fromsnapshot_count(); + auto snap = snapshot->get_snapshot(); + snap.resume_execution_by_model(node, model); + } else { + Profile.incr_restart_count(); + auto timer = ManagedTimer(TimeProfileKind::INSTR); + ExploreTree.reset_cursor(); + reset_stacks(); + CostManager.reset_timer(); + entrypoint(); + } + + GENSYM_INFO("Execution finished successfully"); + } catch (std::runtime_error &e) { + std::cout << "Caught runtime error: " << e.what() << std::endl; + ExploreTree.fillFailedNode(); + + if (std::string(e.what()) == "Symbolic assertion failed") { + GENSYM_INFO("Symbolic assertion failed, continuing to next path..."); + continue; + } + + GENSYM_INFO("Caught runtime error during execution"); + switch (EXPLORE_MODE) { + case ExploreMode::EarlyExit: + return; + case ExploreMode::ExitByCoverage: + if (ExploreTree.all_branch_covered()) { + GENSYM_INFO("All branches covered, exiting..."); + return; + } else { + GENSYM_INFO( + "Found a bug, but not all branches covered, continuing..."); + } + std::cout << e.what() << std::endl; + } + } +#if defined(RUN_ONCE) + return; +#endif + } +} + +inline std::vector> +ConcolicDriver::collect_all_unexplored_path_conds() { + std::vector> result; + for 
(auto node : work_list) { + if (node->isUnexplored()) { + result.push_back(node->collect_path_conds()); + } + } + return result; +} + +inline void ConcolicDriver::run() { + main_exploration_loop(); + auto overall = ExploreTree.read_current_overall_result(); + overall.print(); + Profile.print_summary(); + dump_all_summary_json(Profile, overall); +} + +static void start_concolic_execution_with( + std::function entrypoint, int branchCount) { + + const char *env_tree_file = std::getenv("TREE_FILE"); + + auto tree_file = + env_tree_file ? std::make_optional(env_tree_file) : std::nullopt; + + ConcolicDriver driver = ConcolicDriver( + [=]() { entrypoint(std::monostate{}); }, tree_file, branchCount); + driver.run(); + std::quick_exit(0); +} + +#endif // CONCOLIC_DRIVER_HPP \ No newline at end of file diff --git a/genwasym_runtime/include/wasm/concrete_num.hpp b/genwasym_runtime/include/wasm/concrete_num.hpp new file mode 100644 index 00000000..6a2eafc5 --- /dev/null +++ b/genwasym_runtime/include/wasm/concrete_num.hpp @@ -0,0 +1,989 @@ +#ifndef WASM_CONCRETE_NUM_HPP +#define WASM_CONCRETE_NUM_HPP +#include "wasm/profile.hpp" +#include "wasm/utils.hpp" +#include +#include + +struct Num { + Num(int64_t value) : value(value) {} + Num() : value(0) {} + int64_t value; + + int32_t toInt() const { return static_cast(value); } + uint32_t toUInt() const { return static_cast(value); } + int64_t toInt64() const { return static_cast(value); } + uint64_t toUInt64() const { return static_cast(value); } + float toF32() const { return *reinterpret_cast(&value); } + double toF64() const { return *reinterpret_cast(&value); } + + // debug printer: enabled only when -DDEBUG + static inline void debug_print(const char *op, const Num &a, const Num &b, + const Num &res) { +#ifdef DEBUG_OP + std::cout << "[Debug] " << op << ": lhs=" << static_cast(a.value) + << " rhs=" << static_cast(b.value) + << " -> res=" << static_cast(res.value) << std::endl; +#endif + } + + // Helper to create a Wasm 
Boolean result (1 or 0 as Num) + Num WasmBool(bool condition) const { + Num res(condition ? 1 : 0); + debug_print("WasmBool", *this, *this, res); + return res; + } + // TODO: support different bit width operations, for now we just assume all + // oprands are i32 + // i32.eq (Equals): *this == other + inline Num i32_eq(const Num &other) const { + Num res = WasmBool(this->toUInt() == other.toUInt()); + debug_print("i32.eq", *this, other, res); + return res; + } + + // i32.ne (Not Equals): *this != other + inline Num i32_ne(const Num &other) const { + Num res = WasmBool(this->toUInt() != other.toUInt()); + debug_print("i32.ne", *this, other, res); + return res; + } + + // i32.lt_s (Signed Less Than): *this < other + inline Num i32_lt_s(const Num &other) const { + Num res = WasmBool(this->toInt() < other.toInt()); + debug_print("i32.lt_s", *this, other, res); + return res; + } + + // i32.lt_u (Unsigned Less Than): *this < other (unsigned) + inline Num i32_lt_u(const Num &other) const { + Num res = WasmBool(this->toUInt() < other.toUInt()); + debug_print("i32.lt_u", *this, other, res); + return res; + } + + // i32.le_s (Signed Less Than or Equal): *this <= other + inline Num i32_le_s(const Num &other) const { + Num res = WasmBool(this->toInt() <= other.toInt()); + debug_print("i32.le_s", *this, other, res); + return res; + } + // i32.le_u (Unsigned Less Than or Equal): *this <= other (unsigned) + inline Num i32_le_u(const Num &other) const { + Num res = WasmBool(this->toUInt() <= other.toUInt()); + debug_print("i32.le_u", *this, other, res); + return res; + } + + // i32.gt_s (Signed Greater Than): *this > other + inline Num i32_gt_s(const Num &other) const { + Num res = WasmBool(this->toInt() > other.toInt()); + debug_print("i32.gt_s", *this, other, res); + return res; + } + + // i32.gt_u (Unsigned Greater Than): *this > other (unsigned) + inline Num i32_gt_u(const Num &other) const { + Num res = WasmBool(this->toUInt() > other.toUInt()); + debug_print("i32.gt_u", 
*this, other, res); + return res; + } + + // i32.ge_s (Signed Greater Than or Equal): *this >= other + inline Num i32_ge_s(const Num &other) const { + Num res = WasmBool(this->toInt() >= other.toInt()); + debug_print("i32.ge_s", *this, other, res); + return res; + } + + // i32.ge_u (Unsigned Greater Than or Equal): *this >= other (unsigned) + inline Num i32_ge_u(const Num &other) const { + Num res = WasmBool(this->toUInt() >= other.toUInt()); + debug_print("i32.ge_u", *this, other, res); + return res; + } + + // i32.add (Wrapping addition) + inline Num i32_add(const Num &other) const { + uint32_t result_u = this->toUInt() + other.toUInt(); + Num res(static_cast(result_u)); + debug_print("i32.add", *this, other, res); + return res; + } + + // i32.sub (Wrapping subtraction) + inline Num i32_sub(const Num &other) const { + uint32_t result_u = this->toUInt() - other.toUInt(); + Num res(static_cast(result_u)); + debug_print("i32.sub", *this, other, res); + return res; + } + + // i32.mul (Wrapping multiplication) + inline Num i32_mul(const Num &other) const { + uint32_t result_u = this->toUInt() * other.toUInt(); + Num res(static_cast(result_u)); + debug_print("i32.mul", *this, other, res); + return res; + } + + // i32.div_s (Signed division with traps) + inline Num i32_div_s(const Num &other) const { + int32_t divisor = other.toInt(); + int32_t dividend = this->toInt(); + + if (divisor == 0) { + throw std::runtime_error("i32.div_s: Division by zero"); + } + + Num res(dividend / divisor); + debug_print("i32.div_s", *this, other, res); + return res; + } + + // i32.div_u (Unsigned division with traps) + inline Num i32_div_u(const Num &other) const { + uint32_t divisor = other.toUInt(); + uint32_t dividend = this->toUInt(); + if (divisor == 0) { + throw std::runtime_error("i32.div_u: Division by zero"); + } + Num res(static_cast(dividend / divisor)); + debug_print("i32.div_u", *this, other, res); + return res; + } + + // i32.rem_s (Signed remainder with traps on division by 
zero) + inline Num i32_rem_s(const Num &other) const { + int32_t divisor = other.toInt(); + int32_t dividend = this->toInt(); + if (divisor == 0) { + throw std::runtime_error("i32.rem_s: Division by zero"); + } + // WebAssembly defines INT_MIN % -1 == 0 + if (dividend == INT32_MIN && divisor == -1) { + Num res(0); + debug_print("i32.rem_s", *this, other, res); + return res; + } + Num res(dividend % divisor); + debug_print("i32.rem_s", *this, other, res); + return res; + } + + // i32.rem_u (Unsigned remainder with traps on division by zero) + inline Num i32_rem_u(const Num &other) const { + uint32_t divisor = other.toUInt(); + uint32_t dividend = this->toUInt(); + if (divisor == 0) { + throw std::runtime_error("i32.rem_u: Division by zero"); + } + Num res(static_cast(dividend % divisor)); + debug_print("i32.rem_u", *this, other, res); + return res; + } + + // i32.shl (Shift Left): *this << other (shift count masked by 31) + inline Num i32_shl(const Num &other) const { + uint32_t shift_amount = other.toUInt() & 0x1F; + uint32_t result_u = toUInt() << shift_amount; + Num res(static_cast(result_u)); + debug_print("i32.shl", *this, other, res); + return res; + } + + // i32.shr_s (Signed Shift Right): *this >> other (Arithmetic shift) + inline Num i32_shr_s(const Num &other) const { + // Wasm masks the shift amount by 31 (0x1F) + uint32_t shift_amount = other.toUInt() & 0x1F; + int32_t result_s = toInt() >> shift_amount; + Num res(result_s); + debug_print("i32.shr_s", *this, other, res); + return res; + } + + // i32.shr_u (Unsigned Shift Right): *this >>> other (Logical shift) + inline Num i32_shr_u(const Num &other) const { + // Wasm masks the shift amount by 31 (0x1F) + uint32_t shift_amount = other.toUInt() & 0x1F; + uint32_t result_u = toUInt() >> shift_amount; + Num res(static_cast(result_u)); + debug_print("i32.shr_u", *this, other, res); + return res; + } + + // i32.and (Bitwise AND) + inline Num i32_and(const Num &other) const { + uint32_t result_u = 
this->toUInt() & other.toUInt(); + Num res(static_cast(result_u)); + debug_print("i32.and", *this, other, res); + return res; + } + + // i32.xor (Bitwise XOR) + inline Num i32_xor(const Num &other) const { + uint32_t result_u = this->toUInt() ^ other.toUInt(); + Num res(static_cast(result_u)); + debug_print("i32.xor", *this, other, res); + return res; + } + + inline Num i32_or(const Num &other) const { + uint32_t result_u = this->toUInt() | other.toUInt(); + Num res(static_cast(result_u)); + debug_print("i32.or", *this, other, res); + return res; + } + + // i64.extend_i32_s: sign-extend low 32 bits to i64 + inline Num i32_extend_to_i64_s() const { + int64_t result_s = static_cast(this->toInt()); + Num res(result_s); + debug_print("i32.extend_to_i64_s", *this, *this, res); + return res; + } + + // i64.extend_i32_u: zero-extend low 32 bits to i64 + inline Num i32_extend_to_i64_u() const { + uint64_t result_u = static_cast(this->toUInt()); + Num res(static_cast(result_u)); + debug_print("i32.extend_to_i64_u", *this, *this, res); + return res; + } + + // i64.eq (Equals): *this == other + inline Num i64_eq(const Num &other) const { + Num res = WasmBool(this->toUInt64() == other.toUInt64()); + debug_print("i64.eq", *this, other, res); + return res; + } + + // i64.ne (Not Equals): *this != other + inline Num i64_ne(const Num &other) const { + Num res = WasmBool(this->toUInt64() != other.toUInt64()); + debug_print("i64.ne", *this, other, res); + return res; + } + + // i64.lt_s (Signed Less Than): *this < other + inline Num i64_lt_s(const Num &other) const { + Num res = WasmBool(this->toInt64() < other.toInt64()); + debug_print("i64.lt_s", *this, other, res); + return res; + } + + // i64.lt_u (Unsigned Less Than): *this < other (unsigned) + inline Num i64_lt_u(const Num &other) const { + Num res = WasmBool(this->toUInt64() < other.toUInt64()); + debug_print("i64.lt_u", *this, other, res); + return res; + } + + // i64.le_s (Signed Less Than or Equal): *this <= other + inline 
// i64.le_u: unsigned 64-bit *this <= other.
inline Num i64_le_u(const Num &other) const {
  const bool holds = toUInt64() <= other.toUInt64();
  const Num outcome = WasmBool(holds);
  debug_print("i64.le_u", *this, other, outcome);
  return outcome;
}

// i64.gt_s: signed 64-bit *this > other.
inline Num i64_gt_s(const Num &other) const {
  const bool holds = toInt64() > other.toInt64();
  const Num outcome = WasmBool(holds);
  debug_print("i64.gt_s", *this, other, outcome);
  return outcome;
}

// i64.gt_u: unsigned 64-bit *this > other.
inline Num i64_gt_u(const Num &other) const {
  const bool holds = toUInt64() > other.toUInt64();
  const Num outcome = WasmBool(holds);
  debug_print("i64.gt_u", *this, other, outcome);
  return outcome;
}

// i64.ge_s: signed 64-bit *this >= other.
inline Num i64_ge_s(const Num &other) const {
  const bool holds = toInt64() >= other.toInt64();
  const Num outcome = WasmBool(holds);
  debug_print("i64.ge_s", *this, other, outcome);
  return outcome;
}

// i64.ge_u: unsigned 64-bit *this >= other.
inline Num i64_ge_u(const Num &other) const {
  const bool holds = toUInt64() >= other.toUInt64();
  const Num outcome = WasmBool(holds);
  debug_print("i64.ge_u", *this, other, outcome);
  return outcome;
}
res(static_cast(result_u)); + debug_print("i64.mul", *this, other, res); + return res; + } + + // i64.div_s (Signed division with traps) + inline Num i64_div_s(const Num &other) const { + int64_t divisor = other.toInt64(); + int64_t dividend = this->toInt64(); + + if (divisor == 0) { + throw std::runtime_error("i64.div_s: Division by zero"); + } + if (dividend == INT64_MIN && divisor == -1) { + throw std::runtime_error("i64.div_s: Integer overflow"); + } + + Num res(dividend / divisor); + debug_print("i64.div_s", *this, other, res); + return res; + } + + // i64.div_u (Unsigned division with traps) + inline Num i64_div_u(const Num &other) const { + uint64_t divisor = other.toUInt64(); + uint64_t dividend = this->toUInt64(); + if (divisor == 0) { + throw std::runtime_error("i64.div_u: Division by zero"); + } + Num res(static_cast(dividend / divisor)); + debug_print("i64.div_u", *this, other, res); + return res; + } + + // i64.rem_s (Signed remainder with traps on division by zero) + inline Num i64_rem_s(const Num &other) const { + int64_t divisor = other.toInt64(); + int64_t dividend = this->toInt64(); + if (divisor == 0) { + throw std::runtime_error("i64.rem_s: Division by zero"); + } + // WebAssembly defines INT64_MIN % -1 == 0 + if (dividend == INT64_MIN && divisor == -1) { + Num res(0); + debug_print("i64.rem_s", *this, other, res); + return res; + } + Num res(dividend % divisor); + debug_print("i64.rem_s", *this, other, res); + return res; + } + + // i64.rem_u (Unsigned remainder with traps on division by zero) + inline Num i64_rem_u(const Num &other) const { + uint64_t divisor = other.toUInt64(); + uint64_t dividend = this->toUInt64(); + if (divisor == 0) { + throw std::runtime_error("i64.rem_u: Division by zero"); + } + Num res(static_cast(dividend % divisor)); + debug_print("i64.rem_u", *this, other, res); + return res; + } + + // i64.shl (Shift Left): *this << other (shift count masked by 63) + inline Num i64_shl(const Num &other) const { + uint64_t 
// f32 helpers: move a 32-bit pattern between its integer and float views.

// Returns the float whose object representation is `bits`.
static inline float f32_from_bits(uint32_t bits) {
  union Pun {
    uint32_t word;
    float real;
  };
  Pun view;
  view.word = bits;
  return view.real;
}

// Returns the 32-bit object representation of `f`.
static inline uint32_t f32_to_bits(float f) {
  union Pun {
    uint32_t word;
    float real;
  };
  Pun view;
  view.real = f;
  return view.word;
}
0x007FFFFFu) != 0; + } + + // f32.add + inline Num f32_add(const Num &other) const { + uint32_t a_bits = toUInt(); + uint32_t b_bits = other.toUInt(); + float a = f32_from_bits(a_bits); + float b = f32_from_bits(b_bits); + float r = a + b; + uint32_t r_bits = f32_to_bits(r); + Num res(static_cast(r_bits)); + debug_print("f32.add", *this, other, res); + return res; + } + + // f32.sub + inline Num f32_sub(const Num &other) const { + uint32_t a_bits = toUInt(); + uint32_t b_bits = other.toUInt(); + float a = f32_from_bits(a_bits); + float b = f32_from_bits(b_bits); + float r = a - b; + uint32_t r_bits = f32_to_bits(r); + Num res(static_cast(r_bits)); + debug_print("f32.sub", *this, other, res); + return res; + } + + // f32.mul + inline Num f32_mul(const Num &other) const { + uint32_t a_bits = toUInt(); + uint32_t b_bits = other.toUInt(); + float a = f32_from_bits(a_bits); + float b = f32_from_bits(b_bits); + float r = a * b; + uint32_t r_bits = f32_to_bits(r); + Num res(static_cast(r_bits)); + debug_print("f32.mul", *this, other, res); + return res; + } + + // f32.div + inline Num f32_div(const Num &other) const { + uint32_t a_bits = toUInt(); + uint32_t b_bits = other.toUInt(); + float a = f32_from_bits(a_bits); + float b = f32_from_bits(b_bits); + float r = a / b; + uint32_t r_bits = f32_to_bits(r); + Num res(static_cast(r_bits)); + debug_print("f32.div", *this, other, res); + return res; + } + + // f32.eq : false if either is NaN + inline Num f32_eq(const Num &other) const { + uint32_t a_bits = toUInt(); + uint32_t b_bits = other.toUInt(); + if (f32_is_nan(a_bits) || f32_is_nan(b_bits)) { + Num res = WasmBool(false); + debug_print("f32.eq", *this, other, res); + return res; + } + float a = f32_from_bits(a_bits); + float b = f32_from_bits(b_bits); + Num res = WasmBool(a == b); + debug_print("f32.eq", *this, other, res); + return res; + } + + // f32.ne : true if values are unordered or not equal (i.e., NaN makes it + // true) + inline Num f32_ne(const Num &other) 
const {
+    uint32_t a_bits = toUInt();
+    uint32_t b_bits = other.toUInt();
+    // per wasm: if either is NaN, f32.ne is true
+    if (f32_is_nan(a_bits) || f32_is_nan(b_bits)) {
+      Num res = WasmBool(true);
+      debug_print("f32.ne", *this, other, res);
+      return res;
+    }
+    float a = f32_from_bits(a_bits);
+    float b = f32_from_bits(b_bits);
+    Num res = WasmBool(a != b);
+    debug_print("f32.ne", *this, other, res);
+    return res;
+  }
+
+  // ordered comparisons: return false if any operand is NaN
+  // NOTE(review): unlike f32_eq / f32_ne, the NaN early returns below do not
+  // call debug_print, so those evaluations are missing from the debug trace.
+  inline Num f32_lt(const Num &other) const {
+    uint32_t a_bits = toUInt(), b_bits = other.toUInt();
+    if (f32_is_nan(a_bits) || f32_is_nan(b_bits))
+      return WasmBool(false);
+    float a = f32_from_bits(a_bits), b = f32_from_bits(b_bits);
+    Num res = WasmBool(a < b);
+    debug_print("f32.lt", *this, other, res);
+    return res;
+  }
+  inline Num f32_le(const Num &other) const {
+    uint32_t a_bits = toUInt(), b_bits = other.toUInt();
+    if (f32_is_nan(a_bits) || f32_is_nan(b_bits))
+      return WasmBool(false);
+    float a = f32_from_bits(a_bits), b = f32_from_bits(b_bits);
+    Num res = WasmBool(a <= b);
+    debug_print("f32.le", *this, other, res);
+    return res;
+  }
+  inline Num f32_gt(const Num &other) const {
+    uint32_t a_bits = toUInt(), b_bits = other.toUInt();
+    if (f32_is_nan(a_bits) || f32_is_nan(b_bits))
+      return WasmBool(false);
+    float a = f32_from_bits(a_bits), b = f32_from_bits(b_bits);
+    Num res = WasmBool(a > b);
+    debug_print("f32.gt", *this, other, res);
+    return res;
+  }
+  inline Num f32_ge(const Num &other) const {
+    uint32_t a_bits = toUInt(), b_bits = other.toUInt();
+    if (f32_is_nan(a_bits) || f32_is_nan(b_bits))
+      return WasmBool(false);
+    float a = f32_from_bits(a_bits), b = f32_from_bits(b_bits);
+    Num res = WasmBool(a >= b);
+    debug_print("f32.ge", *this, other, res);
+    return res;
+  }
+
+  // f32.abs: clear sign bit
+  inline Num f32_abs() const {
+    uint32_t a_bits = toUInt();
+    uint32_t r_bits = a_bits & 0x7FFFFFFFu; // works on the bit pattern only, so NaN payloads are left untouched
+    Num res(static_cast(r_bits)); // NOTE(review): the cast's target type is missing in this copy of the patch; restore it from upstream
+    debug_print("f32.abs", *this, 
*this, res);
+    return res;
+  }
+
+  // f32.neg: flip sign bit
+  inline Num f32_neg() const {
+    uint32_t a_bits = toUInt();
+    uint32_t r_bits = a_bits ^ 0x80000000u;
+    Num res(static_cast(r_bits)); // NOTE(review): the cast's target type is missing in this copy of the patch; restore it from upstream
+    debug_print("f32.neg", *this, *this, res); // unary op: *this is passed for both operand slots
+    return res;
+  }
+
+  inline Num convert_i32_to_f32_s() const { // signed 32-bit view -> f32 bit pattern; not traced via debug_print
+    uint32_t r_bits = f32_to_bits(static_cast(toInt()));
+    return Num(static_cast(r_bits));
+  }
+
+  inline Num convert_i32_to_f32_u() const { // unsigned 32-bit view -> f32 bit pattern
+    uint32_t r_bits = f32_to_bits(static_cast(toUInt()));
+    return Num(static_cast(r_bits));
+  }
+
+  inline Num convert_i64_to_f32_s() const { // signed 64-bit view -> f32 bit pattern
+    uint32_t r_bits = f32_to_bits(static_cast(toInt64()));
+    return Num(static_cast(r_bits));
+  }
+
+  inline Num convert_i64_to_f32_u() const { // unsigned 64-bit view -> f32 bit pattern
+    uint32_t r_bits = f32_to_bits(static_cast(toUInt64()));
+    return Num(static_cast(r_bits));
+  }
+
+  // f32.min / f32.max: follow wasm-ish semantics: if either is NaN, return NaN
+  // (propagate)
+  inline Num f32_min(const Num &other) const {
+    uint32_t a_bits = toUInt(), b_bits = other.toUInt();
+    if (f32_is_nan(a_bits))
+      return Num(static_cast(a_bits)); // the NaN operand's own bit pattern is returned unchanged
+    if (f32_is_nan(b_bits))
+      return Num(static_cast(b_bits));
+    float a = f32_from_bits(a_bits), b = f32_from_bits(b_bits);
+    // If values compare equal choose one to preserve signed zero: pick the one
+    // whose sign bit is set for min when both zeros (so -0 wins for min).
+    if (a == b) {
+      if ((a_bits & 0x80000000u) || (b_bits & 0x80000000u))
+        return Num(
+            static_cast((a_bits & 0x80000000u) ? a_bits : b_bits));
+      return Num(static_cast(a_bits));
+    }
+    float r = (a < b) ? 
a : b;
+    uint32_t r_bits = f32_to_bits(r);
+    Num res(static_cast(r_bits)); // NOTE(review): the cast's target type is missing in this copy of the patch; restore it from upstream
+    debug_print("f32.min", *this, other, res); // only this final path is traced; the early returns above are not
+    return res;
+  }
+
+  inline Num f32_max(const Num &other) const {
+    uint32_t a_bits = toUInt(), b_bits = other.toUInt();
+    if (f32_is_nan(a_bits))
+      return Num(static_cast(a_bits)); // the NaN operand's own bit pattern is returned unchanged
+    if (f32_is_nan(b_bits))
+      return Num(static_cast(b_bits));
+    float a = f32_from_bits(a_bits), b = f32_from_bits(b_bits);
+    if (a == b) { // equal values: prefer the operand whose sign bit is clear, so +0 wins over -0
+      if ((a_bits & 0x80000000u) || (b_bits & 0x80000000u))
+        return Num(
+            static_cast((a_bits & 0x80000000u) ? b_bits : a_bits));
+      return Num(static_cast(a_bits));
+    }
+    float r = (a > b) ? a : b;
+    uint32_t r_bits = f32_to_bits(r);
+    Num res(static_cast(r_bits));
+    debug_print("f32.max", *this, other, res);
+    return res;
+  }
+
+  // f32.copysign: result has magnitude of lhs, sign of rhs
+  inline Num f32_copysign(const Num &other) const {
+    uint32_t a_bits = toUInt(), b_bits = other.toUInt();
+    uint32_t r_bits = (a_bits & 0x7FFFFFFFu) | (b_bits & 0x80000000u);
+    Num res(static_cast(r_bits));
+    debug_print("f32.copysign", *this, other, res);
+    return res;
+  }
+
+  // f64 helpers: interpret all 64 bits of value as IEEE-754 double
+  static inline double f64_from_bits(uint64_t bits) {
+    union {
+      uint64_t i;
+      double d;
+    } u; // NOTE(review): reading the inactive union member is not sanctioned by ISO C++; std::memcpy is the portable alternative
+    u.i = bits;
+    return u.d;
+  }
+  static inline uint64_t f64_to_bits(double d) {
+    union {
+      uint64_t i;
+      double d;
+    } u; // inverse of f64_from_bits: store the double, read back its bit pattern
+    u.d = d;
+    return u.i;
+  }
+  static inline bool f64_is_nan(uint64_t bits) {
+    // Exponent all ones and mantissa non-zero -> NaN for IEEE-754 double
+    return (bits & 0x7FF0000000000000ull) == 0x7FF0000000000000ull &&
+           (bits & 0x000FFFFFFFFFFFFFull) != 0;
+  }
+
+  // f64.add: decode both operands with f64_from_bits, add, and store the result's bit pattern
+  inline Num f64_add(const Num &other) const {
+    uint64_t a_bits = toUInt64();
+    uint64_t b_bits = other.toUInt64();
+    double a = f64_from_bits(a_bits);
+    double b = f64_from_bits(b_bits);
+    double r = a + b;
+    uint64_t r_bits = f64_to_bits(r);
+    Num res(static_cast(r_bits));
+    debug_print("f64.add", *this, other, res);
+
return res;
+  }
+
+  // f64.sub: decode both operands with f64_from_bits, compute a - b, store the result's bit pattern
+  inline Num f64_sub(const Num &other) const {
+    uint64_t a_bits = toUInt64();
+    uint64_t b_bits = other.toUInt64();
+    double a = f64_from_bits(a_bits);
+    double b = f64_from_bits(b_bits);
+    double r = a - b;
+    uint64_t r_bits = f64_to_bits(r);
+    Num res(static_cast(r_bits)); // NOTE(review): the cast's target type is missing in this copy of the patch; restore it from upstream
+    debug_print("f64.sub", *this, other, res);
+    return res;
+  }
+
+  // f64.mul: same decode/encode scheme as f64.sub, computing a * b
+  inline Num f64_mul(const Num &other) const {
+    uint64_t a_bits = toUInt64();
+    uint64_t b_bits = other.toUInt64();
+    double a = f64_from_bits(a_bits);
+    double b = f64_from_bits(b_bits);
+    double r = a * b;
+    uint64_t r_bits = f64_to_bits(r);
+    Num res(static_cast(r_bits));
+    debug_print("f64.mul", *this, other, res);
+    return res;
+  }
+
+  // f64.div: same decode/encode scheme as f64.sub, computing a / b (no special-casing of a zero divisor here)
+  inline Num f64_div(const Num &other) const {
+    uint64_t a_bits = toUInt64();
+    uint64_t b_bits = other.toUInt64();
+    double a = f64_from_bits(a_bits);
+    double b = f64_from_bits(b_bits);
+    double r = a / b;
+    uint64_t r_bits = f64_to_bits(r);
+    Num res(static_cast(r_bits));
+    debug_print("f64.div", *this, other, res);
+    return res;
+  }
+
+  // f64.eq : false if either is NaN
+  inline Num f64_eq(const Num &other) const {
+    uint64_t a_bits = toUInt64();
+    uint64_t b_bits = other.toUInt64();
+    if (f64_is_nan(a_bits) || f64_is_nan(b_bits)) { // NaN is detected on the raw bit patterns
+      Num res = WasmBool(false);
+      debug_print("f64.eq", *this, other, res);
+      return res;
+    }
+    double a = f64_from_bits(a_bits);
+    double b = f64_from_bits(b_bits);
+    Num res = WasmBool(a == b);
+    debug_print("f64.eq", *this, other, res);
+    return res;
+  }
+
+  // f64.ne : true if values are unordered or not equal (i.e., NaN makes it
+  // true)
+  inline Num f64_ne(const Num &other) const {
+    uint64_t a_bits = toUInt64();
+    uint64_t b_bits = other.toUInt64();
+    // per wasm: if either is NaN, f64.ne is true
+    if (f64_is_nan(a_bits) || f64_is_nan(b_bits)) {
+      Num res = WasmBool(true);
+      debug_print("f64.ne", *this, other, res);
+      return res;
+    }
+    double a = f64_from_bits(a_bits);
+    double b = 
f64_from_bits(b_bits);
+    Num res = WasmBool(a != b);
+    debug_print("f64.ne", *this, other, res);
+    return res;
+  }
+
+  // ordered comparisons: return false if any operand is NaN
+  // NOTE(review): unlike f64_eq / f64_ne, the NaN early returns below do not
+  // call debug_print, so those evaluations are missing from the debug trace.
+  inline Num f64_lt(const Num &other) const {
+    uint64_t a_bits = toUInt64(), b_bits = other.toUInt64();
+    if (f64_is_nan(a_bits) || f64_is_nan(b_bits))
+      return WasmBool(false);
+    double a = f64_from_bits(a_bits), b = f64_from_bits(b_bits);
+    Num res = WasmBool(a < b);
+    debug_print("f64.lt", *this, other, res);
+    return res;
+  }
+  inline Num f64_le(const Num &other) const {
+    uint64_t a_bits = toUInt64(), b_bits = other.toUInt64();
+    if (f64_is_nan(a_bits) || f64_is_nan(b_bits))
+      return WasmBool(false);
+    double a = f64_from_bits(a_bits), b = f64_from_bits(b_bits);
+    Num res = WasmBool(a <= b);
+    debug_print("f64.le", *this, other, res);
+    return res;
+  }
+  inline Num f64_gt(const Num &other) const {
+    uint64_t a_bits = toUInt64(), b_bits = other.toUInt64();
+    if (f64_is_nan(a_bits) || f64_is_nan(b_bits))
+      return WasmBool(false);
+    double a = f64_from_bits(a_bits), b = f64_from_bits(b_bits);
+    Num res = WasmBool(a > b);
+    debug_print("f64.gt", *this, other, res);
+    return res;
+  }
+  inline Num f64_ge(const Num &other) const {
+    uint64_t a_bits = toUInt64(), b_bits = other.toUInt64();
+    if (f64_is_nan(a_bits) || f64_is_nan(b_bits))
+      return WasmBool(false);
+    double a = f64_from_bits(a_bits), b = f64_from_bits(b_bits);
+    Num res = WasmBool(a >= b);
+    debug_print("f64.ge", *this, other, res);
+    return res;
+  }
+
+  // f64.abs: clear sign bit
+  inline Num f64_abs() const {
+    uint64_t a_bits = toUInt64();
+    uint64_t r_bits = a_bits & 0x7FFFFFFFFFFFFFFFull; // works on the bit pattern only, so NaN payloads are left untouched
+    Num res(static_cast(r_bits)); // NOTE(review): the cast's target type is missing in this copy of the patch; restore it from upstream
+    debug_print("f64.abs", *this, *this, res); // unary op: *this is passed for both operand slots
+    return res;
+  }
+
+  // f64.neg: flip sign bit
+  inline Num f64_neg() const {
+    uint64_t a_bits = toUInt64();
+    uint64_t r_bits = a_bits ^ 0x8000000000000000ull;
+    Num res(static_cast(r_bits));
+    debug_print("f64.neg", *this, *this, res);
+    return res;
+  }
+
+
inline Num convert_i32_to_f64_s() const {
+    uint64_t r_bits = f64_to_bits(static_cast(toInt()));
+    return Num(static_cast(r_bits));
+  }
+
+  inline Num convert_i32_to_f64_u() const {
+    uint64_t r_bits = f64_to_bits(static_cast(toUInt()));
+    return Num(static_cast(r_bits));
+  }
+
+  inline Num convert_i64_to_f64_s() const {
+    uint64_t r_bits = f64_to_bits(static_cast(toInt64()));
+    return Num(static_cast(r_bits));
+  }
+
+  inline Num convert_i64_to_f64_u() const {
+    uint64_t r_bits = f64_to_bits(static_cast(toUInt64()));
+    return Num(static_cast(r_bits));
+  }
+
+  // i32.trunc_f64_u: decode this value as an f64, truncate it toward zero and
+  // return the result as an unsigned 32-bit integer.
+  // Throws std::runtime_error when the input is NaN, is +/-infinity, or when
+  // the truncated value does not fit in [0, 2^32).
+  // NOTE(review): the static_cast target types below are missing in this copy
+  // of the patch; restore them from upstream.
+  inline Num trunc_f64_to_i32_u() const {
+    uint64_t bits = toUInt64();
+    double value = f64_from_bits(bits);
+
+    if (std::isnan(value)) {
+      throw std::runtime_error("i32.trunc_f64_u: NaN");
+    }
+    if (std::isinf(value)) {
+      throw std::runtime_error("i32.trunc_f64_u: Infinity");
+    }
+    // WebAssembly defines trunc_u whenever the *truncated* value is in range,
+    // so inputs in the open interval (-1, 0) are valid and yield 0. Only
+    // values <= -1 (or >= 2^32) trap; the previous `value < 0.0` test wrongly
+    // rejected e.g. -0.5.
+    if (value <= -1.0 || value >= 4294967296.0) {
+      throw std::runtime_error("i32.trunc_f64_u: Out of range");
+    }
+
+    double truncated = std::trunc(value); // -0.0 for inputs in (-1, 0); converts to 0 below
+    uint32_t result = static_cast(truncated);
+    Num res(static_cast(result));
+    debug_print("i32.trunc_f64_u", *this, *this, res);
+    return res;
+  }
+
+  // f64.min / f64.max: follow wasm-ish semantics: if either is NaN, return
+  // NaN (propagate)
+  inline Num f64_min(const Num &other) const {
+    uint64_t a_bits = toUInt64(), b_bits = other.toUInt64();
+    if (f64_is_nan(a_bits))
+      return Num(static_cast(a_bits));
+    if (f64_is_nan(b_bits))
+      return Num(static_cast(b_bits));
+    double a = f64_from_bits(a_bits), b = f64_from_bits(b_bits);
+    // If values compare equal choose one to preserve signed zero: pick the one
+    // whose sign bit is set for min when both zeros (so -0 wins for min).
+    if (a == b) {
+      if ((a_bits & 0x8000000000000000ull) || (b_bits & 0x8000000000000000ull))
+        return Num(static_cast(
+            (a_bits & 0x8000000000000000ull) ? a_bits : b_bits));
+      return Num(static_cast(a_bits));
+    }
+    double r = (a < b) ? 
a : b;
+    uint64_t r_bits = f64_to_bits(r);
+    Num res(static_cast(r_bits)); // NOTE(review): the cast's target type is missing in this copy of the patch; restore it from upstream
+    debug_print("f64.min", *this, other, res); // only this final path is traced; the early returns above are not
+    return res;
+  }
+
+  inline Num f64_max(const Num &other) const {
+    uint64_t a_bits = toUInt64(), b_bits = other.toUInt64();
+    if (f64_is_nan(a_bits))
+      return Num(static_cast(a_bits)); // the NaN operand's own bit pattern is returned unchanged
+    if (f64_is_nan(b_bits))
+      return Num(static_cast(b_bits));
+    double a = f64_from_bits(a_bits), b = f64_from_bits(b_bits);
+    if (a == b) { // equal values: prefer the operand whose sign bit is clear, so +0 wins over -0
+      if ((a_bits & 0x8000000000000000ull) || (b_bits & 0x8000000000000000ull))
+        return Num(static_cast(
+            (a_bits & 0x8000000000000000ull) ? b_bits : a_bits));
+      return Num(static_cast(a_bits));
+    }
+    double r = (a > b) ? a : b;
+    uint64_t r_bits = f64_to_bits(r);
+    Num res(static_cast(r_bits));
+    debug_print("f64.max", *this, other, res);
+    return res;
+  }
+
+  // f64.copysign: result has magnitude of lhs, sign of rhs
+  inline Num f64_copysign(const Num &other) const {
+    uint64_t a_bits = toUInt64(), b_bits = other.toUInt64();
+    uint64_t r_bits =
+        (a_bits & 0x7FFFFFFFFFFFFFFFull) | (b_bits & 0x8000000000000000ull);
+    Num res(static_cast(r_bits));
+    debug_print("f64.copysign", *this, other, res);
+    return res;
+  }
+
+  // logic and: true when both operands are non-zero as seen through toUInt()
+  // NOTE(review): presumably toUInt() exposes only the low 32 bits, so a 64-bit value with only high bits set would count as false - confirm
+  inline bool logical_and(const Num &other) const {
+    return (this->toUInt() != 0) && (other.toUInt() != 0);
+  }
+
+  // logic or: true when either operand is non-zero as seen through toUInt()
+  inline bool logical_or(const Num &other) const {
+    return (this->toUInt() != 0) || (other.toUInt() != 0);
+  }
+};
+
+static Num I32V(int v) { return v; } // builds a Num from an int through Num's converting constructor
+
+static Num I64V(int64_t v) { return v; } // builds a Num from an int64_t through Num's converting constructor
+
+static Num F32V(float f) { // builds a Num from the bit pattern of a float
+  union {
+    uint32_t i;
+    float f;
+  } u;
+  u.f = f;
+  return static_cast(u.i);
+}
+
+static Num F64V(double d) { // builds a Num from the bit pattern of a double
+  union {
+    uint64_t i;
+    double d;
+  } u;
+  u.d = d;
+  return static_cast(u.i);
+}
+
+#endif // WASM_CONCRETE_NUM_HPP
diff --git a/genwasym_runtime/include/wasm/concrete_rt.hpp b/genwasym_runtime/include/wasm/concrete_rt.hpp
new file mode 100644
index 00000000..3e54baac
--- /dev/null
+++ b/genwasym_runtime/include/wasm/concrete_rt.hpp
@@ 
-0,0 +1,497 @@ +#ifndef WASM_CONCRETE_RT_HPP +#define WASM_CONCRETE_RT_HPP + +#include "concrete_num.hpp" +#include "controls.hpp" +#include "immer/vector_transient.hpp" +#include "wasm/profile.hpp" +#include "wasm/utils.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const int STACK_SIZE = 1024 * 64; + +class Stack_t { +public: + Stack_t() : count(0), stack_ptr(new Num[STACK_SIZE]) { + size_t page_size = (size_t)sysconf(_SC_PAGESIZE); + // pre touch the memory to avoid page faults during execution + for (int i = 0; i < STACK_SIZE; i += page_size) { + stack_ptr[i] = Num(0); + } + } + + std::monostate push(Num &&num) { +#ifdef DEBUG + printf("[Debug] pushing a value %ld to stack, size of concrete stack is: " + "%d\n", + num.value, count); +#endif + Profile.step(StepProfileKind::PUSH); + stack_ptr[count] = num; + count++; + return std::monostate{}; + } + + std::monostate push(Num &num) { + Profile.step(StepProfileKind::PUSH); + stack_ptr[count] = num; + count++; + return std::monostate{}; + } + + Num pop() { + Profile.step(StepProfileKind::POP); + assert(count > 0 && "Stack underflow"); +#ifdef DEBUG + printf("[Debug] popping a value %ld from stack, size of concrete stack is: " + "%d\n", + stack_ptr[count - 1].value, count); +#endif + Num num = stack_ptr[count - 1]; + count--; + return num; + } + + Num peek() { + Profile.step(StepProfileKind::PEEK); +#ifdef DEBUG + if (count == 0) { + throw std::runtime_error("Stack underflow"); + } +#endif + return stack_ptr[count - 1]; + } + + int32_t size() { return count; } + + void shift(int32_t offset, int32_t size) { + Profile.step(StepProfileKind::SHIFT); +#ifdef DEBUG + if (offset < 0) { + throw std::out_of_range("Invalid offset: " + std::to_string(offset)); + } + if (size < 0) { + throw std::out_of_range("Invalid size: " + std::to_string(size)); + } + std::cout << "Shifting stack by offset " << offset << " and size " << size + << std::endl; + std::cout << "Current 
stack size: " << count << std::endl; +#endif + // shift last `size` of numbers forward of `offset` + for (int32_t i = count - size; i < count; ++i) { + assert(i - offset >= 0); + stack_ptr[i - offset] = stack_ptr[i]; + } + count -= offset; + } + + void print() { + std::cout << "Stack contents: " << std::endl; + for (int32_t i = 0; i < count; ++i) { + std::cout << stack_ptr[count - i - 1].value << std::endl; + } + std::cout << "End of Stack contents" << std::endl; + } + + void initialize() { + // todo: remove this method + reset(); + } + + void reset() { count = 0; } + + void resize(int32_t new_size) { + assert(new_size >= 0); + count = new_size; + } + + void set_from_front(int32_t index, const Num &num) { + assert(index >= 0 && index < count); + stack_ptr[index] = num; + } + +private: + int32_t count; + Num *stack_ptr; +}; +static Stack_t Stack; +class SymFrames_t; + +const int FRAME_SIZE = 1024 * 8; +class Frames_t { +public: + Frames_t() : count(0), stack_ptr(new Num[FRAME_SIZE]), frame_ptrs() { + size_t page_size = (size_t)sysconf(_SC_PAGESIZE); + // pre touch the memory to avoid page faults during execution + for (int i = 0; i < FRAME_SIZE; i += page_size) { + stack_ptr[i] = Num(0); + } + } + + std::monostate popFrameCaller(std::int32_t size) { + assert(size >= 0); + assert(size <= count); + assert(!frame_ptrs.empty()); + auto frame_base = current_frame_base(); + assert(frame_base + size == count); + count -= size; +#ifdef USE_IMM + frame_ptrs.take(frame_ptrs.size() - 1); +#else + frame_ptrs.pop_back(); +#endif + return std::monostate{}; + } + + std::monostate popFrameCallee(std::int32_t size) { + assert(size >= 0); + assert(size <= count); + count -= size; + return std::monostate{}; + } + + Num get(std::int32_t index) { + assert(!frame_ptrs.empty() && "No active frame"); + auto frame_base = current_frame_base(); + assert(index >= 0 && frame_base + index < count && "Index out of bounds"); + Profile.step(StepProfileKind::GET); + auto ret = stack_ptr[frame_base + 
index]; + return ret; + } + + void set(std::int32_t index, Num num) { + assert(!frame_ptrs.empty() && "No active frame"); + auto frame_base = current_frame_base(); + assert(index >= 0 && frame_base + index < count && "Index out of bounds"); + Profile.step(StepProfileKind::SET); + stack_ptr[frame_base + index] = num; + } + + void pushFrameCaller(std::int32_t size) { + assert(size >= 0); + frame_ptrs.push_back(count); + count += size; + // Zero-initialize the new stack frames. + for (std::int32_t i = 0; i < size; ++i) { + stack_ptr[count - size + i] = Num(0); + } + } + + void pushFrameCallee(std::int32_t size) { + assert(size >= 0); + assert(!frame_ptrs.empty() && "No active frame"); + auto old_count = count; + count += size; + for (std::int32_t i = 0; i < size; ++i) { + stack_ptr[old_count + i] = Num(0); + } + } + + void reset() { + count = 0; +#ifdef USE_IMM + frame_ptrs = immer::vector_transient(); +#else + frame_ptrs.clear(); +#endif + } + + size_t size() const { return count; } + + void set_from_front(int32_t index, const Num &num) { + assert(index >= 0 && index < count && "Index out of bounds"); + stack_ptr[index] = num; + } + + void resize(int32_t new_size) { + assert(new_size >= 0); + count = new_size; + } + +private: + friend class SymFrames_t; + + size_t current_frame_base() const { +#ifdef USE_IMM + return *(frame_ptrs.end() - 1); +#else + return frame_ptrs.back(); +#endif + } + + int32_t count; + Num *stack_ptr; +#ifdef USE_IMM + immer::vector_transient frame_ptrs; +#else + std::vector frame_ptrs; +#endif +}; + +static Frames_t Frames; +static Frames_t Globals; + +static void initRand() { + // for now, just do nothing +} + +static std::monostate unreachable() { + std::cout << "Unreachable code reached!" 
<< std::endl;
+  throw std::runtime_error("Unreachable code reached");
+}
+
+static const int PRE_ALLOC_PAGES = 20;
+static int32_t pagesize = 65536;
+static int32_t page_count = 0; // NOTE(review): Memory_t declares its own page_count member, which hides this global inside the struct
+
+struct Memory_t {
+
+  Memory_t(int32_t init_page_count)
+      : memory(PRE_ALLOC_PAGES * pagesize), init_page_count(init_page_count),
+        page_count(init_page_count), allocated_pages(PRE_ALLOC_PAGES) {
+    reset();
+  }
+
+  int32_t loadInt(int32_t base, int32_t offset) {
+    int32_t addr = base + offset;
+    if (!(addr + 3 < memory.size()) || addr < 0) { // bound is the whole pre-allocated buffer (memory.size()), not page_count * pagesize
+      throw std::runtime_error("Invalid memory access " + std::to_string(addr));
+    }
+    int32_t result = 0;
+    // Little-endian: lowest byte at lowest address
+    for (int i = 0; i < 4; ++i) {
+      result |= static_cast(memory[addr + i]) << (8 * i); // NOTE(review): the cast's target type is missing in this copy of the patch; restore it from upstream
+    }
+#ifdef DEBUG
+    std::cout << "[Debug] loading int " << result << " from memory at address "
+              << addr << std::endl;
+
+#endif
+    // just load a 4-byte integer from memory of the vector
+    return result;
+  }
+
+  uint8_t loadByte(int32_t base, int32_t offset) {
+    int32_t addr = base + offset;
+    if (!(addr < memory.size()) || addr < 0) { // throws instead of asserting, so the check also holds in release builds
+      throw std::runtime_error("Invalid memory access " + std::to_string(addr));
+    }
+    return memory[addr];
+  }
+
+  int32_t loadInt8U(int32_t base, int32_t offset) { // one byte via loadByte; extension depends on the cast type lost in this copy
+    return static_cast(loadByte(base, offset));
+  }
+
+  int32_t loadInt8S(int32_t base, int32_t offset) {
+    return static_cast(loadByte(base, offset));
+  }
+
+  int32_t loadInt16U(int32_t base, int32_t offset) { // two little-endian bytes, each bounds-checked by loadByte
+    uint32_t b0 = static_cast(loadByte(base, offset));
+    uint32_t b1 = static_cast(loadByte(base, offset + 1));
+    return static_cast(b0 | (b1 << 8));
+  }
+
+  int32_t loadInt16S(int32_t base, int32_t offset) {
+    uint32_t b0 = static_cast(loadByte(base, offset));
+    uint32_t b1 = static_cast(loadByte(base, offset + 1));
+    uint16_t raw = static_cast(b0 | (b1 << 8));
+    return static_cast(raw);
+  }
+
+  int64_t loadLong8U(int32_t base, int32_t offset) {
+    return static_cast(loadByte(base, offset));
+  }
+
+  int64_t 
loadLong8S(int32_t base, int32_t offset) { + return static_cast(loadByte(base, offset)); + } + + int64_t loadLong16U(int32_t base, int32_t offset) { + uint64_t b0 = static_cast(loadByte(base, offset)); + uint64_t b1 = static_cast(loadByte(base, offset + 1)); + return static_cast(b0 | (b1 << 8)); + } + + int64_t loadLong16S(int32_t base, int32_t offset) { + uint64_t b0 = static_cast(loadByte(base, offset)); + uint64_t b1 = static_cast(loadByte(base, offset + 1)); + uint16_t raw = static_cast(b0 | (b1 << 8)); + return static_cast(raw); + } + + int64_t loadLong32U(int32_t base, int32_t offset) { + uint64_t b0 = static_cast(loadByte(base, offset)); + uint64_t b1 = static_cast(loadByte(base, offset + 1)); + uint64_t b2 = static_cast(loadByte(base, offset + 2)); + uint64_t b3 = static_cast(loadByte(base, offset + 3)); + return static_cast(b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)); + } + + int64_t loadLong32S(int32_t base, int32_t offset) { + uint64_t b0 = static_cast(loadByte(base, offset)); + uint64_t b1 = static_cast(loadByte(base, offset + 1)); + uint64_t b2 = static_cast(loadByte(base, offset + 2)); + uint64_t b3 = static_cast(loadByte(base, offset + 3)); + uint32_t raw = static_cast(b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)); + return static_cast(raw); + } + + int64_t loadLong(int32_t base, int32_t offset) { + int32_t addr = base + offset; + if (!(addr + 7 < memory.size()) || addr < 0) { + throw std::runtime_error("Invalid memory access " + std::to_string(addr)); + } + int64_t result = 0; + for (int i = 0; i < 8; ++i) { + result |= static_cast(memory[addr + i]) << (8 * i); + } + return result; + } + + std::monostate storeInt(int32_t base, int32_t offset, int32_t value) { + int32_t addr = base + offset; + // Ensure we don't write out of bounds + if (!(addr + 3 < memory.size())) { + throw std::runtime_error("Invalid memory access " + std::to_string(addr)); + } + for (int i = 0; i < 4; ++i) { + memory[addr + i] = static_cast((value >> (8 * i)) & 0xFF); + } +#ifdef DEBUG 
+ std::cout << "[Debug] storing int " << value << " to memory at address " + << addr << std::endl; +#endif + return std::monostate{}; + } + + std::monostate storeLong(int32_t base, int32_t offset, int64_t value) { + int32_t addr = base + offset; + if (!(addr + 7 < memory.size()) || addr < 0) { + throw std::runtime_error("Invalid memory access " + std::to_string(addr)); + } + for (int i = 0; i < 8; ++i) { + memory[addr + i] = static_cast((static_cast(value) >> (8 * i)) & 0xFF); + } + return std::monostate{}; + } + + std::monostate storeInt8(int32_t base, int32_t offset, int32_t value) { + return store_byte(base + offset, static_cast(value & 0xFF)); + } + + std::monostate storeInt16(int32_t base, int32_t offset, int32_t value) { + store_byte(base + offset, static_cast(value & 0xFF)); + store_byte(base + offset + 1, static_cast((value >> 8) & 0xFF)); + return std::monostate{}; + } + + std::monostate storeLong8(int32_t base, int32_t offset, int64_t value) { + return store_byte(base + offset, static_cast(value & 0xFF)); + } + + std::monostate storeLong16(int32_t base, int32_t offset, int64_t value) { + store_byte(base + offset, static_cast(value & 0xFF)); + store_byte(base + offset + 1, static_cast((value >> 8) & 0xFF)); + return std::monostate{}; + } + + std::monostate storeLong32(int32_t base, int32_t offset, int64_t value) { + store_byte(base + offset, static_cast(value & 0xFF)); + store_byte(base + offset + 1, static_cast((value >> 8) & 0xFF)); + store_byte(base + offset + 2, static_cast((value >> 16) & 0xFF)); + store_byte(base + offset + 3, static_cast((value >> 24) & 0xFF)); + return std::monostate{}; + } + + std::monostate store_byte(int32_t addr, uint8_t value) { +#ifdef DEBUG + std::cout << "[Debug] storing byte " << std::to_string(value) + << " to memory at address " << addr << std::endl; +#endif + assert(addr < memory.size()); + memory[addr] = value; + return std::monostate{}; + } + + // grow memory by delta bytes when bytes > 0. 
return -1 if failed, return old + // size when success + int32_t grow(int32_t delta) { + Profile.step(StepProfileKind::MEM_GROW); + if (delta <= 0) { + return page_count * pagesize; + } + + if (page_count + delta < allocated_pages) { + page_count += delta; + return page_count * pagesize; + } + + try { + assert(false && "Use pre-allocated memory, should not reach here"); + memory.resize(memory.size() + delta * pagesize); + auto old_page_count = page_count; + page_count += delta; + return memory.size(); + } catch (const std::bad_alloc &e) { + return -1; + } + } + + void reset() { + page_count = init_page_count; + allocated_pages = PRE_ALLOC_PAGES; + for (int i = 0; i < memory.size() && i < page_count * pagesize; ++i) { + memory[i] = 0; + } + } + +private: + std::vector memory; + int init_page_count; + int page_count; + int allocated_pages; +}; + +static Memory_t Memory(4); // 4 page memory + +struct FuncTable_t { + FuncTable_t() : table(20) {} + std::vector table; + + Func_t read(int32_t index) { + if (index < 0 || index >= table.size()) { + throw std::runtime_error("Function table read out of bounds: " + + std::to_string(index)); + } + if (!table[index]) { + assert(false); + throw std::runtime_error("Function table entry at index " + + std::to_string(index) + " is empty or invalid"); + } + return table[index]; + } + + std::monostate set(Num offset, int32_t index, Func_t func) { + if (index < 0 || index >= table.size()) { + throw std::runtime_error("Function table set out of bounds: " + + std::to_string(index)); + } + table[offset.toInt() + index] = func; + return std::monostate{}; + } +}; + +static FuncTable_t FuncTable; + +#endif // WASM_CONCRETE_RT_HPP diff --git a/genwasym_runtime/include/wasm/config.hpp b/genwasym_runtime/include/wasm/config.hpp new file mode 100644 index 00000000..6b64cbaf --- /dev/null +++ b/genwasym_runtime/include/wasm/config.hpp @@ -0,0 +1,82 @@ +#ifndef CONFIG_HPP +#define CONFIG_HPP + +// This file contains configuration settings for the 
concolic execution
+
+// If ENABLE_PROFILE_STEP is defined, the compiled program collects and prints
+// how many steps of each data structure operation are executed
+#ifdef ENABLE_PROFILE_STEP
+const bool PROFILE_STEP = true;
+#else
+const bool PROFILE_STEP = false;
+#endif
+
+// If ENABLE_PROFILE_TIME is defined, the compiled program collects and prints
+// the time spent in the main loop and in constraint solving
+#ifdef ENABLE_PROFILE_TIME
+const bool PROFILE_TIME = true;
+#else
+const bool PROFILE_TIME = false;
+#endif
+
+#ifdef ENABLE_PROFILE_Z3_API_CALL // set by defining ENABLE_PROFILE_Z3_API_CALL at compile time
+const bool PROFILE_Z3_API_CALL = true;
+#else
+const bool PROFILE_Z3_API_CALL = false;
+#endif
+
+#ifdef ENABLE_PROFILE_CACHE // set by defining ENABLE_PROFILE_CACHE at compile time
+const bool PROFILE_CACHE = true;
+#else
+const bool PROFILE_CACHE = false;
+#endif
+
+#ifdef ENABLE_PROFILE_PATH_CONDS // set by defining ENABLE_PROFILE_PATH_CONDS at compile time
+const bool PROFILE_PATH_CONDS = true;
+#else
+const bool PROFILE_PATH_CONDS = false;
+#endif
+
+// This enum defines when concolic execution stops
+enum class ExploreMode {
+  EarlyExit, // Stop at the first error encountered
+
+  ExitByCoverage // Exit when all syntactic branches are covered
+};
+
+#ifdef EARLY_EXIT
+static const ExploreMode EXPLORE_MODE = ExploreMode::EarlyExit;
+#elif defined(BY_COVERAGE)
+static const ExploreMode EXPLORE_MODE = ExploreMode::ExitByCoverage;
+#else // neither macro defined: default to EarlyExit
+static const ExploreMode EXPLORE_MODE = ExploreMode::EarlyExit;
+#endif
+
+// This variable decides whether we enable the snapshot reuse optimization
+// (enabled unless NO_REUSE is defined)
+#ifdef NO_REUSE
+static const bool REUSE_SNAPSHOT = false;
+#else
+static const bool REUSE_SNAPSHOT = true;
+#endif
+
+// If we use immutable data structures for symbolic states to reduce the cost of
+// copying. 
+#ifdef USE_IMM
+static const bool IMMUTABLE_SYMS = true;
+#else
+static const bool IMMUTABLE_SYMS = false;
+#endif
+
+#ifdef INTERACTIVE // set by defining INTERACTIVE at compile time
+static const bool INTERACTIVE_MODE = true;
+#else
+static const bool INTERACTIVE_MODE = false;
+#endif
+
+#ifdef USE_COST_MODEL // set by defining USE_COST_MODEL at compile time
+static const bool ENABLE_COST_MODEL = true;
+#else
+static const bool ENABLE_COST_MODEL = false;
+#endif
+
+#endif // CONFIG_HPP
\ No newline at end of file
diff --git a/genwasym_runtime/include/wasm/controls.hpp b/genwasym_runtime/include/wasm/controls.hpp
new file mode 100644
index 00000000..f3bbdf91
--- /dev/null
+++ b/genwasym_runtime/include/wasm/controls.hpp
@@ -0,0 +1,105 @@
+
+#ifndef WASM_CONTROLS_HPP
+#define WASM_CONTROLS_HPP
+
+#include
+#include
+
+#include
+#include
+#include
+
+class MContRepr;
+struct MCont_t { // handle around a shared pointer to an MContRepr node; a default-constructed handle is null
+  std::shared_ptr ptr; // NOTE(review): template arguments are missing throughout this copy of the patch; restore them from upstream
+  MCont_t() : ptr(nullptr) {}
+  MCont_t(const MCont_t &p) : ptr(p.ptr) {}
+  MCont_t(std::shared_ptr p) : ptr(p) {}
+  MCont_t(std::function haltK)
+      : ptr(std::make_shared(haltK)) {
+    assert(haltK); // an empty std::function is rejected
+  }
+  bool is_null() const { return ptr == nullptr; }
+
+  std::monostate enter(); // defined after MContRepr, whose definition it needs
+};
+using Cont_t = std::function;
+
+static MCont_t CURRENT_MCONT; // NOTE(review): `static` in a header gives every translation unit its own copy - confirm that is intended
+
+inline std::monostate updateCurrentMCont(MCont_t newMCont) {
+  CURRENT_MCONT = newMCont;
+  return std::monostate{};
+}
+
+class MContRepr { // one node: a continuation `cont` plus the handle `mcont` that becomes current when the node is entered
+  friend std::monostate enterCC(std::monostate);
+
+public:
+  MContRepr(Cont_t cont, MCont_t mcont) : cont(cont), mcont(mcont) {}
+
+  MContRepr(std::function haltK)
+      : cont(haltK), mcont() {}
+
+  // MContRepr() : cont(nullptr), mcont() {}
+
+  std::monostate enter() {
+    // std::cout << "Entering MCont\n";
+    // std::cout << "Cont cont: " << (cont ? "valid" : "null") << "\n";
+    // std::cout << "MCont mcont: " << (mcont ? "valid" : "null") << "\n";
+
+    // This is necessary, because `this` may be deleted
+    // after next line. This copy is cheap because we always store a function
+    // pointer (non captured free variable lambda) in cont.
+    std::monostate (*func_ptr)(std::monostate) = nullptr;
+    {
+      auto cont = this->cont;
+      func_ptr = *cont.target(); // NOTE(review): target() yields a null pointer when the stored callable is not of the requested type; this dereference relies on `cont` always holding a plain function pointer
+    }
+
+    CURRENT_MCONT = mcont;
+
+    return func_ptr(std::monostate{});
+  }
+
+private:
+  Cont_t cont;
+  MCont_t mcont;
+};
+
+inline MCont_t prependCont(Cont_t k, MCont_t mcont) { // wraps a new node holding k and mcont into an MCont_t handle
+  return std::make_shared(k, mcont);
+}
+
+inline std::monostate MCont_t::enter() { return ptr->enter(); } // no null check: calling this on a null handle dereferences nullptr
+
+// Enter the current global MCont (CURRENT_MCONT)
+inline std::monostate enterCC(std::monostate) {
+  // std::cout << "Entering MCont\n";
+  // std::cout << "Cont cont: " << (cont ? "valid" : "null") << "\n";
+  // std::cout << "MCont mcont: " << (mcont ? "valid" : "null") << "\n";
+
+  // Copy the function pointer out first: the node CURRENT_MCONT points to may
+  // be released once CURRENT_MCONT is reassigned below. This copy is cheap
+  // because we always store a function pointer (a lambda without captures)
+  // in cont.
+  std::monostate (*func_ptr)(std::monostate) = nullptr;
+  {
+    auto cont = CURRENT_MCONT.ptr->cont; // no null check: CURRENT_MCONT must not be a null handle here
+    func_ptr = *cont.target();
+  }
+
+  CURRENT_MCONT = CURRENT_MCONT.ptr->mcont;
+
+  __attribute__((musttail)) return func_ptr(std::monostate{});
+}
+
+struct Control { // plain pair of a continuation and an MCont_t handle
+  Cont_t cont;
+  MCont_t mcont;
+
+  Control(Cont_t cont, MCont_t mcont) : cont(cont), mcont(mcont) {}
+};
+
+using Func_t = std::function;
+
+#endif // WASM_CONTROLS_HPP
\ No newline at end of file
diff --git a/genwasym_runtime/include/wasm/heap_mem_bookkeeper.hpp b/genwasym_runtime/include/wasm/heap_mem_bookkeeper.hpp
new file mode 100644
index 00000000..c4cc7313
--- /dev/null
+++ b/genwasym_runtime/include/wasm/heap_mem_bookkeeper.hpp
@@ -0,0 +1,24 @@
+#ifndef HEAP_MEM_BOOKKEEPER_HPP
+#define HEAP_MEM_BOOKKEEPER_HPP
+
+#include
+#include
+
+// Todo: remove this later, this is just a workaround to make sure that the
+// SymVals' memory will not be freed during the main execution. 
+// We can leave the SymVal's memory unmanaged if reference counting is not +// performant +template struct MemBookKeeper { + std::set> allocated; + + template + std::shared_ptr allocate(Args &&...args) { + auto ptr = std::make_shared(std::forward(args)...); + // allocated.insert(ptr); + return ptr; + } + + void clear() { allocated.clear(); } +}; + +#endif // HEAP_MEM_BOOKKEEPER_HPP \ No newline at end of file diff --git a/genwasym_runtime/include/wasm/output_report.hpp b/genwasym_runtime/include/wasm/output_report.hpp new file mode 100644 index 00000000..26bf5410 --- /dev/null +++ b/genwasym_runtime/include/wasm/output_report.hpp @@ -0,0 +1,50 @@ +#ifndef WASM_OUTPUT_REPORT_HPP +#define WASM_OUTPUT_REPORT_HPP + +#include "profile.hpp" +#include "sym_rt.hpp" +#include "config.hpp" +#include + +inline void dump_all_summary_json(const Profile_t &profile, + const OverallResult &overall) { + // use environment variable OUTPUT_FILE to config particular output profiling file + const char *output_file = std::getenv("OUTPUT_FILE"); + if (output_file == nullptr) { + return; + } + + std::filesystem::path report_path(output_file); + + auto parent = report_path.parent_path(); + if (!parent.empty()) { + std::error_code ec; + std::filesystem::create_directories(parent, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + + ec.message()); + } + } + + std::ofstream ofs(report_path); + if (!ofs.is_open()) { + throw std::runtime_error("Failed to open " + report_path.string() + + " for writing"); + } + + // Simple JSON dump (pretty-printed) + ofs << "{\n"; + ofs << " \"unexplored_count\": " << overall.unexplored_count << ",\n"; + ofs << " \"finished_count\": " << overall.finished_count << ",\n"; + ofs << " \"failed_count\": " << overall.failed_count << ",\n"; + ofs << " \"not_to_explore_count\": " << overall.not_to_explore_count + << ",\n"; + ofs << " \"unreachable_count\": " << overall.unreachable_count; + if (PROFILE_STEP || PROFILE_TIME) { + ofs << 
",\n"; + profile.write_as_json(ofs); + } + ofs << "}\n"; + ofs.close(); +} +#endif // WASM_OUTPUT_REPORT_HPP \ No newline at end of file diff --git a/genwasym_runtime/include/wasm/profile.hpp b/genwasym_runtime/include/wasm/profile.hpp new file mode 100644 index 00000000..922b1d45 --- /dev/null +++ b/genwasym_runtime/include/wasm/profile.hpp @@ -0,0 +1,416 @@ +#ifndef PROFILE_HPP +#define PROFILE_HPP + +#include "config.hpp" +#include "utils.hpp" +#include "z3++.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +enum class StepProfileKind { + PUSH, + POP, + PEEK, + SHIFT, + SET, + GET, + BINARY, + TREE_FILL, + CURSOR_MOVE, + MEM_GROW, + SNAPSHOT_CREATE, + SYM_EVAL, + OperationCount // keep this as the last element, this is used to get the + // number of kinds of operations +}; + +enum class ExecutionKind { + RESTART, + FROMSNAPSHOT, + ExecutionKindCount // keep this as the last element, this is used to get the + // number of kinds of operations +}; + +enum class TimeProfileKind { + INSTR, + CALL_Z3_SOLVER, + MAKE_CONJUNCTION, + SOLVER_TOTAL, + RESUME_SNAPSHOT, + COUNT_SYM_SIZE, + SPLIT_CONDITIONS, + COLLECT_PATH_CONDITIONS, + MAIN_LOOP, + TimeOperationCount // keep this as the last element, this is used to get the + // number of kinds of operations +}; + +class Profile_t { +public: + Profile_t() : step_count(0), cache_hit_count(0), cache_miss_count(0) { + // refresh the output profile directory + if (PROFILE_Z3_API_CALL) { + std::filesystem::path out_path(base_profile_output_path); + std::error_code ec; + std::filesystem::remove_all(out_path, ec); + if (ec) { + throw std::runtime_error("Failed to clear output directory: " + + ec.message()); + } + std::filesystem::create_directories(out_path, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + + ec.message()); + } + std::string record_file = + base_profile_output_path + "/z3_solver_time_record.csv"; + std::ofstream ofs(record_file); + ofs 
<< "Expression file,time spent (s),is_sat\n"; + ofs.close(); + std::filesystem::create_directories(z3_expr_output_path, ec); + if (ec) { + throw std::runtime_error("Failed to create z3 expr output directory: " + + ec.message()); + } + } + } + + void cache_hit() { + if (PROFILE_CACHE) + cache_hit_count++; + } + + void cache_miss() { + if (PROFILE_CACHE) + cache_miss_count++; + } + + std::monostate step() { + if (PROFILE_STEP) + step_count++; + return std::monostate(); + } + std::monostate step(StepProfileKind op) { + if (PROFILE_STEP) + op_count[static_cast(op)]++; + return std::monostate(); + } + std::monostate incr_restart_count() { + exec_kind_count[static_cast(ExecutionKind::RESTART)]++; + return std::monostate(); + } + std::monostate incr_fromsnapshot_count() { + exec_kind_count[static_cast(ExecutionKind::FROMSNAPSHOT)]++; + return std::monostate(); + } + std::monostate incr_call_solver_count() { + call_solver_count++; + return std::monostate(); + } + + void print_summary() { + if (PROFILE_STEP) { + std::cout << "Profile Summary:" << std::endl; + std::cout << "Total PUSH operations: " + << op_count[static_cast(StepProfileKind::PUSH)] + << std::endl; + std::cout << "Total POP operations: " + << op_count[static_cast(StepProfileKind::POP)] + << std::endl; + std::cout << "Total PEEK operations: " + << op_count[static_cast(StepProfileKind::PEEK)] + << std::endl; + std::cout << "Total SHIFT operations: " + << op_count[static_cast(StepProfileKind::SHIFT)] + << std::endl; + std::cout << "Total SET operations: " + << op_count[static_cast(StepProfileKind::SET)] + << std::endl; + std::cout << "Total GET operations: " + << op_count[static_cast(StepProfileKind::GET)] + << std::endl; + std::cout << "Total BINARY operations: " + << op_count[static_cast(StepProfileKind::BINARY)] + << std::endl; + std::cout + << "Total TREE_FILL operations: " + << op_count[static_cast(StepProfileKind::TREE_FILL)] + << std::endl; + std::cout + << "Total CURSOR_MOVE operations: " + << 
op_count[static_cast(StepProfileKind::CURSOR_MOVE)] + << std::endl; + std::cout << "Total other instructions executed: " << step_count + << std::endl; + std::cout << "Total MEM_GROW operations: " + << op_count[static_cast(StepProfileKind::MEM_GROW)] + << std::endl; + std::cout << "Total SNAPSHOT_CREATE operations: " + << op_count[static_cast( + StepProfileKind::SNAPSHOT_CREATE)] + << std::endl; + std::cout << "Total SYM_EVAL operations: " + << op_count[static_cast(StepProfileKind::SYM_EVAL)] + << std::endl; + } + if (PROFILE_TIME) { + std::cout << "Time Profile Summary:" << std::endl; + std::cout << "Total time in instruction execution (s): " + << std::setprecision(15) + << time_count[static_cast(TimeProfileKind::INSTR)] + << std::endl; + std::cout + << "Total time in solver (s): " << std::setprecision(15) + << time_count[static_cast(TimeProfileKind::SOLVER_TOTAL)] + << std::endl; + std::cout << "Total time in z3 api call (s): " << std::setprecision(15) + << time_count[static_cast( + TimeProfileKind::CALL_Z3_SOLVER)] + << std::endl; + std::cout << "Total time in resuming from snapshot (s): " + << std::setprecision(15) + << time_count[static_cast( + TimeProfileKind::RESUME_SNAPSHOT)] + << std::endl; + std::cout << "Total time in counting symbolic size (s): " + << std::setprecision(15) + << time_count[static_cast( + TimeProfileKind::COUNT_SYM_SIZE)] + << std::endl; + std::cout << "Total time in splitting path conditions (s): " + << std::setprecision(15) + << time_count[static_cast( + TimeProfileKind::SPLIT_CONDITIONS)] + << std::endl; + std::cout << "Total time in collecting path conditions (s): " + << std::setprecision(15) + << time_count[static_cast( + TimeProfileKind::COLLECT_PATH_CONDITIONS)] + << std::endl; + std::cout + << "Total time in main loop (s): " << std::setprecision(15) + << time_count[static_cast(TimeProfileKind::MAIN_LOOP)] + << std::endl; + } + if (PROFILE_CACHE) { + std::cout << "Solver Cache Summary:" << std::endl; + std::cout << "Total cache 
hits: " << cache_hit_count << std::endl; + std::cout << "Total cache misses: " << cache_miss_count << std::endl; + std::cout << "Time of making conjunctions (s): " << std::setprecision(15) + << time_count[static_cast( + TimeProfileKind::MAKE_CONJUNCTION)] + << std::endl; + std::cout << "Cache hit rate: " + << static_cast(cache_hit_count) / + static_cast(cache_hit_count + cache_miss_count) + << std::endl; + } + if (PROFILE_PATH_CONDS) { + std::cout << "Path Conditions Profile Summary:" << std::endl; + std::cout << "Total time in collecting path conditions (s): " + << std::setprecision(15) + << time_count[static_cast( + TimeProfileKind::COLLECT_PATH_CONDITIONS)] + << std::endl; + } + std::cout << "Number of calls to solver: " << call_solver_count + << std::endl; + std::cout << "Execution Kind Summary:" << std::endl; + std::cout + << "Total RESTART executions: " + << exec_kind_count[static_cast(ExecutionKind::RESTART)] + << std::endl; + std::cout << "Total FROMSNAPSHOT executions: " + << exec_kind_count[static_cast( + ExecutionKind::FROMSNAPSHOT)] + << std::endl; + }
+
+  // Writes the collected counters to `os` as one JSON object member named
+  // "profile_summary" (no trailing comma; the caller closes the outer object).
+  //
+  // Which fields appear depends on the PROFILE_STEP / PROFILE_TIME /
+  // PROFILE_CACHE switches. Separators are emitted on demand, so the object is
+  // valid JSON for every combination of switches (the previous hand-placed
+  // commas produced a missing or trailing comma for several combinations).
+  void write_as_json(std::ostream &os) const {
+    bool first_field = true;
+    // Emits `"key": value`, preceded by ",\n" for every field but the first.
+    auto field = [&os, &first_field](const char *key, const auto &value) {
+      if (!first_field)
+        os << ",\n";
+      first_field = false;
+      os << "    \"" << key << "\": " << value;
+    };
+    auto ops = [this](StepProfileKind kind) {
+      return op_count[static_cast<std::size_t>(kind)];
+    };
+    auto seconds = [this](TimeProfileKind kind) {
+      return time_count[static_cast<std::size_t>(kind)];
+    };
+
+    os << "  \"profile_summary\": {\n";
+    if (PROFILE_STEP) {
+      field("total_push_operations", ops(StepProfileKind::PUSH));
+      field("total_pop_operations", ops(StepProfileKind::POP));
+      field("total_peek_operations", ops(StepProfileKind::PEEK));
+      field("total_shift_operations", ops(StepProfileKind::SHIFT));
+      field("total_set_operations", ops(StepProfileKind::SET));
+      field("total_get_operations", ops(StepProfileKind::GET));
+      field("total_binary_operations", ops(StepProfileKind::BINARY));
+      field("total_tree_fill_operations", ops(StepProfileKind::TREE_FILL));
+      field("total_cursor_move_operations",
+            ops(StepProfileKind::CURSOR_MOVE));
+      field("total_other_instructions_executed", step_count);
+      field("total_mem_grow_operations", ops(StepProfileKind::MEM_GROW));
+      field("total_snapshot_create_operations",
+            ops(StepProfileKind::SNAPSHOT_CREATE));
+      field("total_sym_eval_operations", ops(StepProfileKind::SYM_EVAL));
+    }
+    if (PROFILE_TIME) {
+      // Sticky on the stream, exactly as before: every floating-point value
+      // written from here on uses 15 significant digits.
+      os << std::setprecision(15);
+      field("total_time_instruction_execution_s",
+            seconds(TimeProfileKind::INSTR));
+      // NOTE(review): print_summary() reports SOLVER_TOTAL as "time in solver"
+      // and CALL_Z3_SOLVER as "time in z3 api call"; this key has always been
+      // fed from CALL_Z3_SOLVER and is kept that way here - confirm intent.
+      field("total_time_solver_s", seconds(TimeProfileKind::CALL_Z3_SOLVER));
+      field("total_time_resuming_from_snapshot_s",
+            seconds(TimeProfileKind::RESUME_SNAPSHOT));
+      field("total_time_counting_symbolic_size_s",
+            seconds(TimeProfileKind::COUNT_SYM_SIZE));
+      field("total_time_splitting_path_conditions_s",
+            seconds(TimeProfileKind::SPLIT_CONDITIONS));
+    }
+    if (PROFILE_CACHE) {
+      field("total_cache_hits", cache_hit_count);
+      field("total_cache_misses", cache_miss_count);
+      // JSON has no NaN literal: report 0 when no lookup happened instead of
+      // printing the result of 0.0 / 0.0.
+      const int lookups = cache_hit_count + cache_miss_count;
+      const double hit_rate =
+          lookups == 0 ? 0.0 : static_cast<double>(cache_hit_count) / lookups;
+      field("cache_hit_rate", hit_rate);
+    }
+    if (!first_field)
+      os << "\n";
+    os << "  }\n";
+  }
+
+ void record_z3_solver_time(z3::solver expr, double time, bool is_sat) { + // Write z3 expression in a file, and write the time spent in solving it and + // the file path in another file + if (PROFILE_Z3_API_CALL) { + static int count = 0; + std::string expr_file = + z3_expr_output_path + "/z3_expr_" + std::to_string(count) + ".smt2"; + std::error_code ec; + std::ofstream
ofs(expr_file); + ofs << expr; + ofs.close(); + std::string record_file = + base_profile_output_path + "/z3_solver_time_record.csv"; + std::ofstream rofs(record_file, std::ios::app); + rofs << expr_file << "," << std::setprecision(15) << time << "," + << (is_sat ? "sat" : "unsat") << "\n"; + rofs.close(); + count++; + } + } + + std::string base_profile_output_path = "genwasym_profile_output"; + std::string z3_expr_output_path = "genwasym_profile_output/z3_expressions"; + + // record the time spent in main instruction execution, in seconds + void add_instruction_time(TimeProfileKind kind, double time) { + time_count[static_cast(kind)] += time; + } + + void remove_instruction_time(TimeProfileKind kind, double time) { + time_count[static_cast(kind)] -= time; + } + + int step_count; + std::array(StepProfileKind::OperationCount)> + op_count; + std::array(TimeProfileKind::TimeOperationCount)> + time_count; + std::array(ExecutionKind::ExecutionKindCount)> + exec_kind_count; + + int cache_hit_count; + int cache_miss_count; + int call_solver_count; +}; + +static Profile_t Profile; + +class ManagedTimer { +public: + ManagedTimer() = delete; + ManagedTimer(TimeProfileKind kind) : kind(kind), time_ref(nullptr) { + start = std::chrono::high_resolution_clock::now(); + } + ManagedTimer(TimeProfileKind kind, double &time_ref) + : kind(kind), time_ref(&time_ref) { + start = std::chrono::high_resolution_clock::now(); + } + ~ManagedTimer() { + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + Profile.add_instruction_time(kind, elapsed.count()); + if (time_ref != nullptr) { + *time_ref += elapsed.count(); + } + } + +private: + TimeProfileKind kind; + std::chrono::high_resolution_clock::time_point start; + double *time_ref; +}; + +using Time = std::chrono::time_point; + +inline Time getCurrentTime() { return std::chrono::steady_clock::now(); } + +inline double duration_time(Time start, Time end) { + std::chrono::duration duration = 
end - start; + return duration.count(); +} + +struct CostManager_t { + Time start_time; + + CostManager_t() : start_time() {} + + std::monostate reset_timer() { + start_time = getCurrentTime(); + return std::monostate(); + } + + double dump_instr_cost() { + auto current = getCurrentTime(); + double duration = duration_time(start_time, current); + reset_timer(); + return normalize_cost(duration); + } + + double normalize_cost(double cost) { + // Just return duration time as it is + return 1 * cost; + } +}; + +static CostManager_t CostManager; + +#endif // PROFILE_HPP \ No newline at end of file diff --git a/genwasym_runtime/include/wasm/smt_solver.hpp b/genwasym_runtime/include/wasm/smt_solver.hpp new file mode 100644 index 00000000..562eeecb --- /dev/null +++ b/genwasym_runtime/include/wasm/smt_solver.hpp @@ -0,0 +1,459 @@ +#ifndef SMT_SOLVER_HPP +#define SMT_SOLVER_HPP + +#include "concrete_rt.hpp" +#include "sym_rt.hpp" +#include "union_find.hpp" +#include "utils.hpp" +#include "wasm/profile.hpp" +#include "wasm/symbolic_decl.hpp" +#include "z3++.h" +#include "z3_env.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct QueryResult { + ImmNumMapBox map_box; + z3::model model; +}; + +struct QueryResultWithWitness : public QueryResult { + QueryResultWithWitness(ImmNumMapBox map_box, z3::model model, + NodeBox *witness) + : QueryResult{map_box, model}, witness(witness) {} + NodeBox *witness; +}; + +static QueryResult +compose_query_results(const std::vector &results) { + ManagedTimer timer(TimeProfileKind::SPLIT_CONDITIONS); + NumMap combined_map; + z3::model combined_model(global_z3_ctx()); + for (const auto &res : results) { + auto num_map = res.map_box; + for (const auto &[id, num] : *num_map) { + assert( + combined_map.find(id) == combined_map.end() && + "Conflicting symbolic environment ids when composing query results"); + combined_map[id] = num; + } + const z3::model &model = res.model; + for 
(unsigned i = 0; i < model.num_consts(); ++i) { + z3::func_decl decl = model.get_const_decl(i); + std::string name = decl.name().str(); + assert((starts_with(name, "s_int") || starts_with(name, "s_f32") || + starts_with(name, "s_f64")) && + "Unexpected declaration in query result model"); + assert(!combined_model.has_interp(decl) && + "Internal Error: Conflicting constant declarations when composing query results"); + z3::expr value = model.get_const_interp(decl); + combined_model.add_const_interp(decl, value); + } + } + ImmNumMapBox combined_map_box(combined_map); + return QueryResult{combined_map_box, combined_model}; +} + +// VectorGroupResult groups a vector. key is the vector index, and value is the +// group id, ungrouped items do not have group id +using VectorGroupMap = std::unordered_map; + +struct GroupResult { + std::unordered_map> conds_in_groups; + std::vector ungrouped_conds; +}; + +static std::optional group_of_symval(const SymVal &sym, UnionFind &uf) { + // TODO: This process is un optimized and slow, just want to see if the idea + // of independent resolving works + if (auto symbol = dynamic_cast(sym.symptr.get())) { + return symbol->get_id(); + } else if (auto concrete = dynamic_cast(sym.symptr.get())) { + return std::nullopt; + } else if (auto binary = dynamic_cast(sym.symptr.get())) { + auto left_group = group_of_symval(binary->lhs, uf); + auto right_group = group_of_symval(binary->rhs, uf); + if (left_group.has_value() && right_group.has_value()) { + uf.unite(*left_group, *right_group); + return uf.find(*left_group); + } else if (left_group.has_value()) { + return uf.find(*left_group); + } else if (right_group.has_value()) { + return uf.find(*right_group); + } else { + return std::nullopt; + } + } else if (auto unary = dynamic_cast(sym.symptr.get())) { + return group_of_symval(unary->value, uf); + } else if (auto extract = dynamic_cast(sym.symptr.get())) { + return group_of_symval(extract->value, uf); + } + return std::nullopt; +} + +static 
VectorGroupMap build_group_map(const std::vector &conditions) { + // TODO: This is a slow temporary solution which only used for validating the + // idea of independent constraint resolving, the intermediate result of + // independent solving is reusable + ManagedTimer timer(TimeProfileKind::SPLIT_CONDITIONS); + if (conditions.empty()) { + return VectorGroupMap{}; + } + // use union find to group the conditions + UnionFind uf; + for (const auto &cond : conditions) { + group_of_symval(cond, uf); + } + + VectorGroupMap result; + + for (size_t i = 0; i < conditions.size(); ++i) { + auto group = group_of_symval(conditions[i], uf); + if (group.has_value()) { + int group_id = uf.find(*group); + result[i] = group_id; + } + } + return result; +} + +static struct GroupResult +split_conditions(const std::vector &conditions, + const VectorGroupMap &group_map, + const std::unordered_set unused_indexes = {}) { + ManagedTimer timer(TimeProfileKind::SPLIT_CONDITIONS); + std::vector ungrouped_conds; + std::unordered_map> conds_in_groups; + for (size_t i = 0; i < conditions.size(); ++i) { + auto cond = conditions[i]; + if (group_map.find(i) != group_map.end()) { + int group_id = group_map.at(i); + conds_in_groups[group_id].push_back(cond); + } else { + if (unused_indexes.find(i) == unused_indexes.end()) { + ungrouped_conds.push_back(cond); + } + } + } + return GroupResult{conds_in_groups, ungrouped_conds}; +} + +class Solver { +public: + Solver() {} + + // Solve the path conditions. 
`only_latest_unseen` indicates whether the + // previous conditions and the negation of the latest condition have been + // reached before + std::optional solve_path_conds(std::vector &conditions, + bool only_latest_unseen) { + if (conditions.empty()) { + return QueryResult{ImmNumMapBox(NumMap{}), z3::model(global_z3_ctx())}; + } + + // split the conditions into independent groups + auto group_map = build_group_map(conditions); + + if (only_latest_unseen) { + auto latest_pc_index = 0; + if (group_map.find(latest_pc_index) == group_map.end()) { + // the latest path condition is pure concrete, it must be a + // unsatisfiable concrete condition, because its negation has been + // executed before + return std::nullopt; + } + } + + GroupResult groups = split_conditions(conditions, group_map); + + if (only_latest_unseen) { + // We can safely remove all previously seen pure concrete conditions. + // They are satisfiable and contribute nothing to model. + groups.ungrouped_conds.clear(); + } + + return solve_by_groups(groups, group_map, conditions.size()); + } + + std::optional solve(const std::vector &conditions) { + if (conditions.empty()) { + return QueryResult{ImmNumMapBox(NumMap{}), z3::model(global_z3_ctx())}; + } + + // split the conditions into independent groups + VectorGroupMap group_map = build_group_map(conditions); + GroupResult groups = split_conditions(conditions, group_map); + + return solve_by_groups(groups, group_map, conditions.size()); + } + + std::optional + solve_under_reachable_path(std::vector &&conditions, + SymVal extra_cond) { + conditions.push_back(extra_cond); + VectorGroupMap group_map = build_group_map(conditions); + std::unordered_set unused_indexes; + for (size_t i = 0; i < conditions.size() - 1; ++i) { + if (group_map.find(i) == group_map.end()) { + // this condition is pure concrete, and has been executed before, it + // must be satisfiable and can be ignored + unused_indexes.insert(i); + } + } + GroupResult groups = + 
split_conditions(conditions, group_map, unused_indexes); + return solve_by_groups(groups, group_map, conditions.size()); + } + + std::optional find_reachable_path_with_witness( + const std::vector> &all_conditions, + const std::vector &candidate_nodes) { + assert(all_conditions.size() == candidate_nodes.size() && + "Conditions size and candidate nodes size must be equal"); + std::vector disjuncts; + auto witness = SymVal::get_witness_symbol(); + SymVal disjunction = SVFactory::FALSE; + { + ManagedTimer timer(TimeProfileKind::COLLECT_PATH_CONDITIONS); + for (size_t i = 0; i < all_conditions.size(); ++i) { + const auto &conds = all_conditions[i]; + auto clause = make_conjunction(conds, true); + clause = clause.land( + witness.eq_bool(SVFactory::make_concrete_bv(Num(i), 32))); + + disjuncts.push_back(clause); + } + disjunction = make_disjunction(disjuncts); + } + + auto result = solve_group({disjunction}, false); + if (!result.has_value()) { + return std::nullopt; + } + z3::model &model = result->model; + // find which clause in disjunct is satisfied + z3::expr witness_expr = model.eval(witness->z3_expr(), true); + int witness_index = witness_expr.get_numeral_int64(); + + return QueryResultWithWitness{ + result->map_box, + result->model, + candidate_nodes[witness_index], + }; + } + +private: + std::optional solve_group(const std::vector &conditions, + bool is_bv) { + + z3::solver z3_solver(global_z3_ctx()); + SymVal conjunction = SVFactory::TRUE; + z3::check_result solver_result; + double z3_solver_time = 0.0; + { + auto timer = + ManagedTimer(TimeProfileKind::CALL_Z3_SOLVER, z3_solver_time); + Profile.incr_call_solver_count(); + // make an conjunction of all conditions + conjunction = make_conjunction(conditions, is_bv); + // call z3 to solve the condition + if (auto it = solver_cache.find(conjunction); it != solver_cache.end()) { + Profile.cache_hit(); + return it->second; + } + Profile.cache_miss(); + SymValSet added_conds; + for (size_t i = 0; i < 
conditions.size(); ++i) { + SymVal temp = is_bv ? conditions[i].bv2bool() : conditions[i]; + if (added_conds.find(temp) != added_conds.end()) { + continue; + } + z3_solver.add(temp->z3_expr()); + added_conds.insert(temp); + } + GENSYM_INFO("Solving conditions with Z3 solver..."); + solver_result = z3_solver.check(); + } + Profile.record_z3_solver_time(z3_solver, z3_solver_time, + solver_result == z3::sat); + switch (solver_result) { + case z3::unsat: + solver_cache[conjunction] = std::nullopt; + return std::nullopt; // No solution found + case z3::sat: { + z3::model model = z3_solver.get_model(); + NumMap result; + // Reference: + // https://github.com/Z3Prover/z3/blob/master/examples/c%2B%2B/example.cpp#L59 + GENSYM_INFO("Solved Z3 model"); + GENSYM_INFO(model); + for (unsigned i = 0; i < model.size(); ++i) { + z3::func_decl var = model[i]; + z3::expr value = model.get_const_interp(var); + std::string name = var.name().str(); + if (starts_with(name, "s_int")) { + int id = std::stoi(name.substr(std::string("s_int").length())); + z3::expr evaluated = model.eval(value, true); + uint64_t bits = evaluated.get_numeral_uint64(); + int64_t raw = 0; + std::memcpy(&raw, &bits, sizeof(raw)); + result[id] = Num(raw); + } else if (starts_with(name, "s_f32")) { + int id = std::stoi(name.substr(std::string("s_f32").length())); + z3::expr evaluated = model.eval(value.mk_to_ieee_bv(), true); + uint64_t bits = evaluated.get_numeral_uint64(); + result[id] = Num(static_cast(static_cast(bits))); + } else if (starts_with(name, "s_f64")) { + int id = std::stoi(name.substr(std::string("s_f64").length())); + z3::expr evaluated = model.eval(value.mk_to_ieee_bv(), true); + uint64_t bits = evaluated.get_numeral_uint64(); + int64_t raw = 0; + std::memcpy(&raw, &bits, sizeof(raw)); + result[id] = Num(raw); + } else { + GENSYM_INFO("Find a variable that is not created by GenSym: " + name); + } + } + ImmNumMapBox map_box(result); + QueryResult query_result{map_box, model}; + 
solver_cache[conjunction] = query_result; + return query_result; + } + case z3::unknown: + throw std::runtime_error("Z3 solver returned unknown status"); + } + return std::nullopt; // Should not reach here + } + + std::optional solve_by_groups(const GroupResult &groups, + const VectorGroupMap &group_map, + int condition_size) { + + if (!solve_group(groups.ungrouped_conds, true).has_value()) { + return std::nullopt; + } + + std::vector group_results; + std::unordered_set processed_groups; + for (size_t i = 0; i < condition_size; ++i) { + if (group_map.find(i) == group_map.end()) { + // ungrouped condition, skip it + continue; + } + int group_id = group_map.at(i); + if (processed_groups.find(group_id) != processed_groups.end()) { + // already processed + continue; + } + processed_groups.insert(group_id); + auto &group_conds = groups.conds_in_groups.at(group_id); + auto group_result = solve_group(group_conds, true); + if (!group_result.has_value()) { + // this group is unsatisfiable, so the whole condition is + // unsatisfiable + return std::nullopt; + } + group_results.push_back(group_result.value()); + } + + // combine the results from all groups + return compose_query_results(group_results); + } + + // make a big conjunction from a list of bitvector symbolic values + SymVal make_conjunction(const std::vector &conditions, bool is_bv) { + ManagedTimer timer(TimeProfileKind::MAKE_CONJUNCTION); + SymVal result = SVFactory::make_concrete_bool(true); // true + SymValSet added_conds; + for (size_t i = 0; i < conditions.size(); ++i) { + SymVal temp = is_bv ? 
conditions[i].bv2bool() : conditions[i]; + if (added_conds.find(temp) != added_conds.end()) { + continue; + } + added_conds.insert(temp); + result = result.land(temp); + } + return result; + } + + // make a big disjunction from a list of bool symbolic values + SymVal make_disjunction(const std::vector &conditions) { + SymVal fls = SVFactory::make_concrete_bool(false); // false + SymVal result = fls; + SymValSet added_conds; + for (size_t i = 0; i < conditions.size(); ++i) { + if (added_conds.find(conditions[i]) != added_conds.end()) { + continue; + } + added_conds.insert(conditions[i]); + result = result.lor(conditions[i]); + } + return result; + } + + z3::expr to_z3_conjunction(std::vector &conditions) { + z3::expr conjunction = global_z3_ctx().bool_val(true); + for (auto &cond : conditions) { + auto z3_cond = cond->z3_expr(); + conjunction = conjunction && z3_cond != global_z3_ctx().bv_val(0, 32); + } +#ifdef DEBUG + // std::cout << "Symbolic conditions size: " << conditions.size() << + // std::endl; std::cout << "Solving conditions: " << conjunction << + // std::endl; +#endif + return conjunction; + } + + SymValMap> solver_cache; +}; + +static Solver solver;
+
+// Evaluates `sym` under `model` and returns its raw bits together with the bit
+// width and the value kind. Model completion is enabled, so Z3 chooses a value
+// for every symbol the model leaves unconstrained.
+//
+// Only KindBV and KindFP are supported. Any other kind trips the assert in
+// debug builds and throws std::runtime_error otherwise (previously a KindBool
+// value fell through into the KindFP branch when asserts were compiled out,
+// and an unknown kind ran off the end of the function).
+inline EvalRes eval_sym_expr_by_model(const SymVal &sym, z3::model &model) {
+  // Z3 bit-vector numerals are unsigned, so get_numeral_int64() cannot
+  // represent a 64-bit pattern whose top bit is set (for example the IEEE
+  // bits of a negative f64). Read the value as uint64_t and reinterpret the
+  // bits, the same way Solver::solve_group decodes model values.
+  const auto numeral_bits = [](const z3::expr &numeral) {
+    const uint64_t bits = numeral.get_numeral_uint64();
+    int64_t raw = 0;
+    std::memcpy(&raw, &bits, sizeof(raw));
+    return raw;
+  };
+
+  auto expr = sym->z3_expr();
+  switch (sym->value_kind()) {
+  case KindBV: {
+    z3::expr value = model.eval(expr, true);
+    int width = expr.get_sort().bv_size();
+    return EvalRes(Num(numeral_bits(value)), width, KindBV);
+  }
+  case KindFP: {
+    // Floating-point values are reported as their IEEE-754 bit pattern.
+    z3::expr value = model.eval(expr.mk_to_ieee_bv(), true);
+    int width = get_z3_fp_sort_size(expr.get_sort());
+    return EvalRes(Num(numeral_bits(value)), width, KindFP);
+  }
+  case KindBool:
+    break;
+  }
+  assert(false && "unreachable");
+  throw std::runtime_error("eval_sym_expr_by_model: unsupported value kind");
+}
+
+inline std::monostate GENSYM_SYM_ASSERT(SymVal &sym_cond) { + ManagedTimer timer(TimeProfileKind::SOLVER_TOTAL); + auto start =
std::chrono::steady_clock::now(); + std::vector conds = ExploreTree.collect_current_path_conds(); + auto result = solver.solve_under_reachable_path( + std::move(conds), sym_cond.bv_negate().bool2bv()); + auto end = std::chrono::steady_clock::now(); + auto time_need_to_be_removed = std::chrono::duration(end - start); + Profile.remove_instruction_time(TimeProfileKind::INSTR, + time_need_to_be_removed.count()); + if (result.has_value()) { + std::cout << "Symbolic assertion failed" << std::endl; + throw std::runtime_error("Symbolic assertion failed"); + } + return std::monostate{}; +} + +#endif // SMT_SOLVER_HPP diff --git a/genwasym_runtime/include/wasm/sym_rt.hpp b/genwasym_runtime/include/wasm/sym_rt.hpp new file mode 100644 index 00000000..760112b5 --- /dev/null +++ b/genwasym_runtime/include/wasm/sym_rt.hpp @@ -0,0 +1,1926 @@ +#ifndef WASM_SYMBOLIC_RT_HPP +#define WASM_SYMBOLIC_RT_HPP + +#include "concrete_rt.hpp" +#include "config.hpp" +#include "controls.hpp" +#include "heap_mem_bookkeeper.hpp" +#include "immer/map.hpp" +#include "immer/map_transient.hpp" +#include "immer/vector.hpp" +#include "immer/vector_transient.hpp" +#include "profile.hpp" +#include "symbolic_decl.hpp" +#include "symbolic_impl.hpp" +#include "symval_decl.hpp" +#include "symval_factory.hpp" +#include "symval_impl.hpp" +#include "utils.hpp" +#include "wasm/concrete_num.hpp" +#include "wasm/z3_env.hpp" +#include "z3++.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class Snapshot_t; + +class SymStack_t { +public: + void push(SymVal val) { + // Push a symbolic value to the stack + stack.push_back(val); + } + + SymVal pop() { + // Pop a symbolic value from the stack +#ifdef DEBUG + printf("[Debug] poping from stack, size of symbolic stack is: %zu\n", + stack.size()); +#endif +#ifdef USE_IMM + auto ret = *(stack.end() - 1); + stack.take(stack.size() - 1); + return ret; +#else + auto ret = 
stack.back(); + stack.pop_back(); + return ret; +#endif + } + + SymVal peek() { return *(stack.end() - 1); }
+
+  // Slides the top `size` values down by `offset` slots, overwriting the
+  // values that were directly beneath them, and then drops the top `offset`
+  // slots, so the stack ends up `offset` entries shorter.
+  //
+  // Preconditions (checked with assert): offset >= 0, size >= 0 and
+  // offset + size <= stack.size().
+  std::monostate shift(int32_t offset, int32_t size) {
+    const size_t n = stack.size();
+    // The former per-iteration check `i - offset >= 0` was evaluated in size_t
+    // and therefore could never fail. Validate the signed arguments once, up
+    // front, so a bad shift is caught before any slot is overwritten.
+    assert(offset >= 0 && size >= 0);
+    assert(static_cast<size_t>(offset) + static_cast<size_t>(size) <= n);
+    const size_t shift_by = static_cast<size_t>(offset);
+    for (size_t i = n - static_cast<size_t>(size); i < n; ++i) {
+#ifdef USE_IMM
+      stack.set(i - shift_by, stack[i]);
+#else
+      stack[i - shift_by] = stack[i];
+#endif
+    }
+#ifdef USE_IMM
+    stack.take(n - shift_by);
+#else
+    stack.erase(stack.begin() + (n - shift_by), stack.end());
+#endif
+    return std::monostate();
+  }
+
+ void reset() { +// Reset the symbolic stack +#ifdef USE_IMM + stack = immer::vector_transient(); +#else + stack.clear(); +#endif + symbolic_size = 0; + } + + size_t size() const { return stack.size(); } + + SymVal operator[](size_t index) const { return stack[index]; } + + int total_sym_size() const { + ManagedTimer timer(TimeProfileKind::COUNT_SYM_SIZE); + int total_size = 0; + for (const auto &val : stack) { + // std::cout << "Symbolic Expression: " << val->z3_expr() << "\n"; + // std::cout << "Val size: " << val.size() << "\n"; + total_size += val->size(); + } + return total_size; + } + +private: + int symbolic_size = 0; +#ifdef USE_IMM + immer::vector_transient stack; +#else + std::vector stack; +#endif +}; + +static SymStack_t SymStack; + +class SymFrames_t { + +public: + void restore_frame_ptr(Frames_t &frame) const; + + void pushFramePtr() { +#ifdef USE_IMM + frame_ptrs.push_back(stack.size()); +#else + frame_ptrs.push_back(stack.size()); +#endif + } + + void pushFrameSlot(int width) { +#ifdef USE_IMM + stack.push_back(SVFactory::make_concrete_bv(I64V(0), width)); +#else + stack.emplace_back(SVFactory::make_concrete_bv(I64V(0), width)); +#endif + } + + std::monostate popFrameCaller(int size) { + assert(size >= 0); + assert(static_cast(size) <= stack.size()); + assert(!frame_ptrs.empty()); + auto frame_base = current_frame_base(); + assert(frame_base + size == stack.size()); + +#ifdef USE_IMM + stack.take(stack.size() - size); +#else +
stack.erase(stack.end() - size, stack.end()); +#endif + +#ifdef USE_IMM + frame_ptrs.take(frame_ptrs.size() - 1); +#else + frame_ptrs.pop_back(); +#endif + + return std::monostate{}; + } + + std::monostate popFrameCallee(int size) { + // Pop the frame of the given size + assert(size >= 0); + assert(static_cast(size) <= stack.size()); + +#ifdef USE_IMM + stack.take(stack.size() - size); +#else + stack.erase(stack.end() - size, stack.end()); +#endif + + return std::monostate{}; + } + + SymVal get(int index) { + // Get the symbolic value at the given frame index + assert(!frame_ptrs.empty()); + auto frame_base = current_frame_base(); + assert(index >= 0 && + static_cast(frame_base + index) < stack.size()); + auto res = stack[frame_base + index]; + return res; + } + + void set(int index, SymVal val) { + // Set the symbolic value at the given index + assert(val.symptr != nullptr); + assert(!frame_ptrs.empty()); + auto frame_base = current_frame_base(); + assert(index >= 0 && + static_cast(frame_base + index) < stack.size()); +#ifdef USE_IMM + stack.set(frame_base + index, val); +#else + stack[frame_base + index] = val; +#endif + } + + void reset() { + // Reset the symbolic frames + +#ifdef USE_IMM + stack = immer::vector_transient(); + frame_ptrs = immer::vector_transient(); +#else + stack.clear(); + frame_ptrs.clear(); +#endif + symbolic_size = 0; + } + + size_t size() const { return stack.size(); } + + SymVal operator[](size_t index) const { return stack[index]; } + + int total_sym_size() const { + ManagedTimer timer(TimeProfileKind::COUNT_SYM_SIZE); + int total_size = 0; + for (const auto &val : stack) { + total_size += val->size(); + } + return total_size; + } + +private: + size_t current_frame_base() const { +#ifdef USE_IMM + return *(frame_ptrs.end() - 1); +#else + return frame_ptrs.back(); +#endif + } + + int symbolic_size = 0; +#ifdef USE_IMM + immer::vector_transient frame_ptrs; + immer::vector_transient stack; +#else + std::vector frame_ptrs; + std::vector 
stack; +#endif +}; + +struct NodeBox; +struct SymEnv_t; + +class SymMemory_t { +public: +#ifdef USE_IMM + immer::map_transient memory; +#else + std::unordered_map memory; +#endif + int symbolic_size = 0; + + SymVal loadSymByte(int32_t addr) { +// if the address is not in the memory, it must be a zero-initialized memory +#ifdef USE_IMM + auto it = memory.find(addr); + if (it != nullptr) { + return *it; + } else { + auto s = SVFactory::ZeroByte; + return s; + } +#else + auto it = memory.find(addr); + SymVal s = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + return s; +#endif + } + + SymVal loadSym(int32_t base, int32_t offset) { + // calculate the real address + +#ifdef USE_IMM + int32_t addr = base + offset; + auto it = memory.find(addr); + SymVal s0 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 1); + SymVal s1 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 2); + SymVal s2 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 3); + SymVal s3 = it ? *it : SVFactory::ZeroByte; + + return s3.concat(s2).concat(s1).concat(s0); +#else + int32_t addr = base + offset; + auto it = memory.find(addr); + SymVal s0 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 1); + SymVal s1 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 2); + SymVal s2 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 3); + SymVal s3 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + + return s3.concat(s2).concat(s1).concat(s0); +#endif + } + + SymVal loadSymLong(int32_t base, int32_t offset) { +#ifdef USE_IMM + int32_t addr = base + offset; + auto it = memory.find(addr); + SymVal s0 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 1); + SymVal s1 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 2); + SymVal s2 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 3); + SymVal s3 = it ? 
*it : SVFactory::ZeroByte; + it = memory.find(addr + 4); + SymVal s4 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 5); + SymVal s5 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 6); + SymVal s6 = it ? *it : SVFactory::ZeroByte; + it = memory.find(addr + 7); + SymVal s7 = it ? *it : SVFactory::ZeroByte; +#else + int32_t addr = base + offset; + auto it = memory.find(addr); + SymVal s0 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 1); + SymVal s1 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 2); + SymVal s2 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 3); + SymVal s3 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 4); + SymVal s4 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 5); + SymVal s5 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 6); + SymVal s6 = (it != memory.end()) ? it->second : SVFactory::ZeroByte; + it = memory.find(addr + 7); + SymVal s7 = (it != memory.end()) ? 
it->second : SVFactory::ZeroByte;
+#endif
+
+    return s7.concat(s6)
+        .concat(s5)
+        .concat(s4)
+        .concat(s3)
+        .concat(s2)
+        .concat(s1)
+        .concat(s0);
+  }
+
+  // Load the 4 bytes at base + offset and rebuild them as a concrete 32-bit
+  // floating-point value. Throws std::runtime_error when the loaded value
+  // cannot be viewed as a concrete value.
+  SymVal loadSymFloat(int32_t base, int32_t offset) {
+    // For simplicity, we treat float as concrete value for now
+    auto symbv = loadSym(base, offset);
+    assert(symbv.is_concrete() && "Currently only support concrete symbolic "
+                                  "value for float-point values");
+    if (auto concrete = dynamic_cast(symbv.symptr.get())) {
+      auto value = concrete->value;
+      return SVFactory::make_concrete_fp(value, 32);
+    }
+    assert(false && "unreachable");
+    // assert() is compiled out under NDEBUG; throw so that control never
+    // flows off the end of a value-returning function.
+    throw std::runtime_error("loadSymFloat: loaded value is not concrete");
+  }
+
+  // Load the 8 bytes at base + offset and rebuild them as a concrete 64-bit
+  // floating-point value. Throws std::runtime_error when the loaded value
+  // cannot be viewed as a concrete value.
+  SymVal loadSymDouble(int32_t base, int32_t offset) {
+    // For simplicity, we treat double as concrete value for now
+    auto symbv = loadSymLong(base, offset);
+    assert(symbv.is_concrete() && "Currently only support concrete symbolic "
+                                  "value for float-point values");
+    if (auto concrete = dynamic_cast(symbv.symptr.get())) {
+      auto value = concrete->value;
+      return SVFactory::make_concrete_fp(value, 64);
+    }
+    assert(false && "unreachable");
+    // Same reasoning as loadSymFloat: never fall off the end in release
+    // builds.
+    throw std::runtime_error("loadSymDouble: loaded value is not concrete");
+  }
+
+  // One byte, prefixed with 24 zero bits.
+  SymVal loadSymInt8U(int32_t base, int32_t offset) {
+    return SVFactory::make_smallbv(24, 0).concat(loadSymByte(base + offset));
+  }
+
+  // One byte, widened by shifting left and then shr_s by 24.
+  SymVal loadSymInt8S(int32_t base, int32_t offset) {
+    auto value = loadSymInt8U(base, offset);
+    auto shift = SVFactory::make_concrete_bv(I32V(24), 32);
+    return value.shl(shift).shr_s(shift);
+  }
+
+  // Two bytes (the byte at the lower address ends up last in the concat
+  // chain), prefixed with 16 zero bits.
+  SymVal loadSymInt16U(int32_t base, int32_t offset) {
+    auto low = loadSymByte(base + offset);
+    auto high = loadSymByte(base + offset + 1);
+    return SVFactory::make_smallbv(16, 0).concat(high).concat(low);
+  }
+
+  // Two bytes, widened by shifting left and then shr_s by 16.
+  SymVal loadSymInt16S(int32_t base, int32_t offset) {
+    auto value = loadSymInt16U(base, offset);
+    auto shift = SVFactory::make_concrete_bv(I32V(16), 32);
+    return value.shl(shift).shr_s(shift);
+  }
+
+  // One byte, prefixed with 56 zero bits.
+  SymVal loadSymLong8U(int32_t base, int32_t offset) {
+    return SVFactory::make_smallbv(56, 0).concat(loadSymByte(base + offset));
+  }
+
+  SymVal
loadSymLong8S(int32_t base, int32_t offset) { + auto value = loadSymLong8U(base, offset); + auto shift = SVFactory::make_concrete_bv(I64V(56), 64); + return value.shl(shift).shr_s(shift); + } + + SymVal loadSymLong16U(int32_t base, int32_t offset) { + auto low = loadSymByte(base + offset); + auto high = loadSymByte(base + offset + 1); + return SVFactory::make_smallbv(48, 0).concat(high).concat(low); + } + + SymVal loadSymLong16S(int32_t base, int32_t offset) { + auto value = loadSymLong16U(base, offset); + auto shift = SVFactory::make_concrete_bv(I64V(48), 64); + return value.shl(shift).shr_s(shift); + } + + SymVal loadSymLong32U(int32_t base, int32_t offset) { + auto b0 = loadSymByte(base + offset); + auto b1 = loadSymByte(base + offset + 1); + auto b2 = loadSymByte(base + offset + 2); + auto b3 = loadSymByte(base + offset + 3); + return SVFactory::make_smallbv(32, 0).concat(b3).concat(b2).concat(b1).concat(b0); + } + + SymVal loadSymLong32S(int32_t base, int32_t offset) { + auto value = loadSymLong32U(base, offset); + auto shift = SVFactory::make_concrete_bv(I64V(32), 64); + return value.shl(shift).shr_s(shift); + } + + // when loading a symval, we need to concat 4 symbolic values + // This sounds terribly bad for SMT... + // Load a 4-byte symbolic value from memory + // Store a 4-byte symbolic value to memory + std::monostate storeSym(int32_t base, int32_t offset, SymVal value) { + int32_t addr = base + offset; + // Extract 4 bytes from that symbol + SymVal s0 = value.extract(1, 1); + SymVal s1 = value.extract(2, 2); + SymVal s2 = value.extract(3, 3); + SymVal s3 = value.extract(4, 4); + storeSymByte(addr, s0); + storeSymByte(addr + 1, s1); + storeSymByte(addr + 2, s2); + storeSymByte(addr + 3, s3); + return std::monostate{}; + } + + std::monostate storeSymLong(int32_t base, int32_t offset, SymVal value) { + int32_t addr = base + offset; + // TODO: Can we receive a float point symbolic value here? 
which may produce a bug + SymVal s0 = value.extract(1, 1); + SymVal s1 = value.extract(2, 2); + SymVal s2 = value.extract(3, 3); + SymVal s3 = value.extract(4, 4); + SymVal s4 = value.extract(5, 5); + SymVal s5 = value.extract(6, 6); + SymVal s6 = value.extract(7, 7); + SymVal s7 = value.extract(8, 8); + storeSymByte(addr, s0); + storeSymByte(addr + 1, s1); + storeSymByte(addr + 2, s2); + storeSymByte(addr + 3, s3); + storeSymByte(addr + 4, s4); + storeSymByte(addr + 5, s5); + storeSymByte(addr + 6, s6); + storeSymByte(addr + 7, s7); + return std::monostate{}; + } + + std::monostate storeSymInt8(int32_t base, int32_t offset, SymVal value) { + int32_t addr = base + offset; + storeSymByte(addr, value.extract(1, 1)); + return std::monostate{}; + } + + std::monostate storeSymInt16(int32_t base, int32_t offset, SymVal value) { + int32_t addr = base + offset; + storeSymByte(addr, value.extract(1, 1)); + storeSymByte(addr + 1, value.extract(2, 2)); + return std::monostate{}; + } + + std::monostate storeSymLong8(int32_t base, int32_t offset, SymVal value) { + int32_t addr = base + offset; + storeSymByte(addr, value.extract(1, 1)); + return std::monostate{}; + } + + std::monostate storeSymLong16(int32_t base, int32_t offset, SymVal value) { + int32_t addr = base + offset; + storeSymByte(addr, value.extract(1, 1)); + storeSymByte(addr + 1, value.extract(2, 2)); + return std::monostate{}; + } + + std::monostate storeSymLong32(int32_t base, int32_t offset, SymVal value) { + int32_t addr = base + offset; + storeSymByte(addr, value.extract(1, 1)); + storeSymByte(addr + 1, value.extract(2, 2)); + storeSymByte(addr + 2, value.extract(3, 3)); + storeSymByte(addr + 3, value.extract(4, 4)); + return std::monostate{}; + } + + std::monostate storeSymFloat(int32_t base, int32_t offset, SymVal value) { + assert(value.is_concrete() && "Currently only support concrete symbolic " + "value for float-point values"); + return storeSym(base, offset, value); + } + + std::monostate 
storeSymDouble(int32_t base, int32_t offset, SymVal value) {
+    assert(value.is_concrete() && "Currently only support concrete symbolic "
+                                  "value for float-point values");
+    return storeSymLong(base, offset, value);
+  }
+
+  // Store one symbolic byte at `addr`, overwriting any value already there.
+  std::monostate storeSymByte(int32_t addr, SymVal value) {
+    // assume the input value is 8-bit symbolic value
+    // The earlier version also looked the address up twice more (an
+    // `exists` flag and the previous value) without using either result;
+    // those lookups are gone, the stored state is unchanged.
+#ifdef USE_IMM
+    memory.set(addr, value);
+#else
+    memory.insert_or_assign(addr, value);
+#endif
+    return std::monostate{};
+  }
+
+  // Drop every stored byte.
+  std::monostate reset() {
+#ifdef USE_IMM
+    memory = immer::map_transient();
+#else
+    memory.clear();
+#endif
+    return std::monostate{};
+  }
+
+  // Sum of size() over every stored byte value; the traversal is timed under
+  // TimeProfileKind::COUNT_SYM_SIZE.
+  int total_sym_size() const {
+    ManagedTimer timer(TimeProfileKind::COUNT_SYM_SIZE);
+    int total_size = 0;
+    for (const auto &[_, val] : memory) {
+      total_size += val->size();
+    }
+    return total_size;
+  }
+};
+
+inline void SymFrames_t::restore_frame_ptr(Frames_t &frame) const {
+  frame.frame_ptrs = frame_ptrs;
+}
+
+static SymMemory_t SymMemory;
+
+// Copy `data` into both the concrete and the symbolic memory, starting at
+// `offset`.
+static std::monostate memoryInitialize(int32_t offset,
+                                       const std::string &data) {
+  // initialize concrete memory
+  for (size_t i = 0; i < data.size(); ++i) {
+    Memory.storeInt(offset, i, static_cast(data[i]));
+  }
+  // initialize symbolic memory
+  for (size_t i = 0; i < data.size(); ++i) {
+    SymMemory.storeSymByte(
+        offset + i, SVFactory::make_smallbv(8, static_cast(data[i])));
+  }
+  return {};
+}
+
+using NumMap = std::unordered_map;
+
+// TODO: remove this class later
+class ImmNumMapBox {
+public:
+  ImmNumMapBox(const NumMap &sym_env)
+      : map_ptr(std::make_shared(
+            sym_env) /* create a immutable copy of SymEnv */
+        ) {}
+
+  const NumMap *operator->() const { return map_ptr.get(); }
+  const NumMap &operator*() const { return *map_ptr; }
+
+private: + std::shared_ptr map_ptr; +}; + +class SymEnv_t { +public: + SymEnv_t() : map(), imm_map_box(map) {} + + Num read(const Symbol &symbol) const { +#if DEBUG + std::cout << "Read symbol: " << symbol.get_id() + << " from symbolic environment" << std::endl; + std::cout << "Current symbolic environment: " << to_string() << std::endl; +#endif + if (map.find(symbol.get_id()) == map.end()) { + return Num(I32V(0)); + } + return map.at(symbol.get_id()); + } + + Num read(SymVal sym) { + // Read the value of a symbolic value from the environment, it will update + // the environment if the key does not exist. + auto symbol = dynamic_cast(sym.symptr.get()); + assert(symbol); + return read(*symbol); + } + + void update(NumMap new_env) { + map = std::move(new_env); + imm_map_box = ImmNumMapBox(map); + } + + // Absorb another symbolic environment into this one, if some keys not exist + // in another environment and exist in this one, they will be kept unchanged. + void absorb(const NumMap &other) { + for (const auto &[id, num] : other) { + map[id] = num; + } + imm_map_box = ImmNumMapBox(map); + } + + std::string to_string() const { + std::string result; + result += "(\n"; + for (const auto &[id, num] : map) { + result += + " (" + std::to_string(id) + "->" + std::to_string(num.value) + ")\n"; + } + result += ")"; + return result; + } + + size_t size() const { return map.size(); } + + ImmNumMapBox get_num_map() const { return imm_map_box; } + +private: + NumMap map; // The symbolic environment, a vector of Num + ImmNumMapBox imm_map_box; +}; + +static SymEnv_t SymEnv; + +// A snapshot of the symbolic state and execution context (control) +class Snapshot_t { +public: + explicit Snapshot_t(Cont_t cont, MCont_t mcont, SymStack_t stack, + SymFrames_t frames, + SymFrames_t globals, SymMemory_t memory, ImmNumMapBox num_map /* Current num map that corresponds to the symbolic environment */); + + SymStack_t get_stack() const { return stack; } + SymFrames_t get_frames() const { 
return frames; } + SymFrames_t get_globals() const { return globals; } + SymMemory_t get_memory() const { return memory; } + + std::monostate resume_execution(NodeBox *node) const; + std::monostate resume_execution_by_model(NodeBox *node, + z3::model &model) const; + + double cost_of_snapshot() const; + +private: + SymStack_t stack; + SymFrames_t frames; + SymFrames_t globals; + SymMemory_t memory; + // The continuation at the snapshot point + Cont_t cont; + MCont_t mcont; + ImmNumMapBox num_map; + void restore_states_to_global() const; +}; + +static SymFrames_t SymFrames; +static SymFrames_t SymGlobals; + +static Control makeControl(Cont_t cont, MCont_t mcont) { + return Control(cont, mcont); +} + +static Snapshot_t makeSnapshot(Control control) { + // create a snapshot from the current symbolic states and the control + return Snapshot_t(control.cont, control.mcont, SymStack, SymFrames, + SymGlobals, SymMemory, SymEnv.get_num_map()); +} + +struct Node; + +struct NodeBox { + explicit NodeBox(NodeBox *parent); + std::unique_ptr node; + NodeBox *parent; + double instr_cost() const; + + bool fillIfElseNode(SymVal cond, int id); + bool fillCallIndirectNode(SymVal cond, int id); + std::monostate fillFinishedNode(); + std::monostate fillFailedNode(); + std::monostate fillUnreachableNode(); + std::monostate fillSnapshotNode(Snapshot_t snapshot); + std::monostate fillNotToExploreNode(); + bool isUnexplored() const; + bool isSnapshotNode() const; + std::vector collect_path_conds(); + immer::vector collect_path_conds_imm(); + + void reach_here(std::function); + + Node *operator->() { + assert(node != nullptr && "Accessing an empty NodeBox"); + return node.get(); + } +}; + +struct Node { + friend struct NodeBox; + virtual ~Node(){}; + void set_cost(double c) { instr_cost = c; } + double get_cost() const { return instr_cost; } + virtual std::string to_string() = 0; + void to_graphviz(std::ostream &os) { + os << "digraph G {\n"; + os << " rankdir=TB;\n"; + os << " node 
[shape=box, style=filled, fillcolor=lightblue];\n";
+    current_id = 0;
+    generate_dot(os, -1, "");
+
+    os << "}\n";
+  }
+  virtual void generate_dot(std::ostream &os, int parent_dot_id,
+                            const std::string &edge_label) = 0;
+
+protected:
+  // Counter for unique node IDs across the entire graph, only for generating
+  // graphviz purpose. It is a C++17 inline variable defined in-class, so
+  // including this header from several translation units still produces a
+  // single definition at link time.
+  static inline int current_id = 0;
+
+  // Emit one Graphviz node statement with the given label, shape and colour.
+  void graphviz_node(std::ostream &os, const int node_id,
+                     const std::string &label, const std::string &shape,
+                     const std::string &fillcolor) {
+    os << " node" << node_id << " [label=\"" << label << "\", shape=" << shape
+       << ", style=filled, fillcolor=" << fillcolor << "];\n";
+  }
+
+  // Emit one Graphviz edge statement; the label attribute is omitted when
+  // `edge_label` is empty.
+  void graphviz_edge(std::ostream &os, int from_id, int target_id,
+                     const std::string &edge_label) {
+    os << " node" << from_id << " -> node" << target_id;
+    if (!edge_label.empty()) {
+      os << " [label=\"" << edge_label << "\"]";
+    }
+    os << ";\n";
+  }
+
+private:
+  double instr_cost = 0.0;
+  std::optional> path_conds_cache;
+};
+
+// Cost recorded on the contained node, or 0.0 when the box holds no node.
+inline double NodeBox::instr_cost() const {
+  if (node) {
+    return node->get_cost();
+  } else {
+    return 0.0;
+  }
+}
+
+// Branch node of the explore tree: a condition plus one child box per
+// outcome.
+struct IfElseNode : Node {
+  SymVal cond;
+  std::unique_ptr true_branch;
+  std::unique_ptr false_branch;
+  int id;
+
+  IfElseNode(SymVal cond, NodeBox *parent, int id)
+      : cond(cond), true_branch(std::make_unique(parent)),
+        false_branch(std::make_unique(parent)), id(id) {}
+
+  std::string to_string() override {
+    std::string result = "IfElseNode {\n";
+    result += " true_branch: ";
+    // Check the box and the node it owns, as CallIndirectNode::to_string
+    // does, so an empty box prints "nullptr" instead of being dereferenced.
+    if (true_branch && true_branch->node) {
+      result += true_branch->node->to_string();
+    } else {
+      result += "nullptr";
+    }
+    result += "\n";
+
+    result += " false_branch: ";
+    if (false_branch && false_branch->node) {
+      result += false_branch->node->to_string();
+    } else {
+      result += "nullptr";
+    }
+    result += "\n";
+    result += "}";
+    return result;
+  }
+
+  void generate_dot(std::ostream &os, int parent_dot_id,
+                    const
std::string &edge_label) override {
+    int current_node_dot_id = current_id;
+    current_id += 1;
+
+    graphviz_node(os, current_node_dot_id, "If", "diamond", "lightyellow");
+
+    // Draw edge from parent if this is not the root node
+    if (parent_dot_id != -1) {
+      graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label);
+    }
+    assert(true_branch != nullptr);
+    assert(true_branch->node != nullptr);
+    true_branch->node->generate_dot(os, current_node_dot_id, "true");
+    assert(false_branch != nullptr);
+    assert(false_branch->node != nullptr);
+    false_branch->node->generate_dot(os, current_node_dot_id, "false");
+  }
+};
+
+// Explore-tree node for a call_indirect site: one child box per target seen
+// so far (`branches`) plus `otherwise_branch` for every remaining target.
+struct CallIndirectNode : Node {
+  SymVal cond;
+  std::unordered_map> branches;
+  std::unique_ptr otherwise_branch;
+  int id;
+  // Initializers are listed in declaration order (cond, branches,
+  // otherwise_branch, id); members are always constructed in that order, and
+  // listing them differently triggers -Wreorder under -Wall.
+  CallIndirectNode(SymVal cond, NodeBox *parent, int id)
+      : cond(cond), otherwise_branch(std::make_unique(parent)), id(id) {}
+  std::string to_string() override {
+    std::string result = "CallIndirectNode {\n";
+    for (const auto &pair : branches) {
+      result += " branch " + std::to_string(pair.first) + ": ";
+      if (pair.second && pair.second->node) {
+        result += pair.second->node->to_string();
+      } else {
+        result += "nullptr";
+      }
+      result += "\n";
+    }
+    result += "}";
+    return result;
+  }
+
+  void generate_dot(std::ostream &os, int parent_dot_id,
+                    const std::string &edge_label) override {
+    int current_node_dot_id = current_id;
+    current_id += 1;
+
+    graphviz_node(os, current_node_dot_id, "Branch", "diamond", "lightyellow");
+
+    // Draw edge from parent if this is not the root node
+    if (parent_dot_id != -1) {
+      graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label);
+    }
+    for (const auto &pair : branches) {
+      assert(pair.second != nullptr);
+      assert(pair.second->node != nullptr);
+      pair.second->node->generate_dot(os, current_node_dot_id,
+                                      "branch " + std::to_string(pair.first));
+    }
+    // The otherwise subtree is part of the explore tree as well; without
+    // this call it would be missing from the Graphviz dump.
+    if (otherwise_branch && otherwise_branch->node) {
+      otherwise_branch->node->generate_dot(os, current_node_dot_id,
+                                           "otherwise");
+    }
+  }
+};
+
+struct UnExploredNode : Node {
+  UnExploredNode() {}
+  std::string to_string() override { return "UnexploredNode"; }
+ +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Unexplored", "octagon", + "lightgrey"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct NotToExploreNode : Node { + NotToExploreNode() {} + std::string to_string() override { return "NotToExploreNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "NotToExplore", "box", "grey"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct SnapshotNode : Node { + SnapshotNode(Snapshot_t snapshot) : snapshot(snapshot) {} + std::string to_string() override { return "SnapshotNode"; } + const Snapshot_t &get_snapshot() const { return snapshot; } + Snapshot_t move_out_snapshot() { return std::move(snapshot); } + + bool worth_to_reuse() const { + if (!ENABLE_COST_MODEL) { + // If we are not using cost model, always create snapshot + return REUSE_SNAPSHOT; + } + // find out the best way to reach the current position via our cost model + auto snapshot_cost = snapshot.cost_of_snapshot(); + double re_execution_cost = get_cost(); + // std::cout << "Snapshot cost: " << snapshot_cost + // << ", re-execution cost: " << re_execution_cost << std::endl; + if (snapshot_cost <= re_execution_cost) { + GENSYM_INFO("Snapshot is worth to create"); + } else { + GENSYM_INFO("Snapshot is NOT worth to create"); + } + return snapshot_cost <= re_execution_cost; + } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Snapshot", "box", "lightblue"); + + if (parent_dot_id != 
-1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } + +private: + Snapshot_t snapshot; +}; + +struct Finished : Node { + Finished() {} + std::string to_string() override { return "FinishedNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Finished", "box", "lightgreen"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct Failed : Node { + Failed() {} + std::string to_string() override { return "FailedNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Failed", "box", "red"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct Unreachable : Node { + Unreachable() {} + std::string to_string() override { return "UnreachableNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Unreachable", "box", "orange"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +inline NodeBox::NodeBox(NodeBox *parent) + : node(std::make_unique()), + /* TODO: avoid allocation of unexplored node */ + parent(parent) {} + +inline bool NodeBox::fillIfElseNode(SymVal cond, int id) { + // fill the current NodeBox with an ifelse branch node when it's unexplored + double cost_from_parent = CostManager.dump_instr_cost(); + double cost_from_root = + cost_from_parent + (this->parent ? 
this->parent->instr_cost() : 0);
+  // std::cout << "Cost from parent: " << cost_from_parent
+  //           << ", cost from root: " << cost_from_root << std::endl;
+
+  if (auto ptr = dynamic_cast(node.get())) {
+    node = std::make_unique(cond, this, id);
+    node->set_cost(cost_from_root);
+    return true;
+  } else if (dynamic_cast(node.get())) {
+    node = std::make_unique(cond, this, id);
+    node->set_cost(cost_from_root);
+    return true;
+  } else if (dynamic_cast(node.get()) != nullptr) {
+    assert(false &&
+           "Unexpected traversal: arrived at a node marked 'NotToExplore'.");
+    return false;
+  }
+
+  node->set_cost(cost_from_root);
+  assert(
+      dynamic_cast(node.get()) != nullptr &&
+      "Current node is not an Unexplored nor an IfElseNode, cannot fill it!");
+  return false;
+}
+
+inline bool NodeBox::fillCallIndirectNode(SymVal cond, int id) {
+  // fill the current NodeBox with a call_indirect branch node when it's
+  // unexplored
+  if (auto ptr = dynamic_cast(node.get())) {
+    node = std::make_unique(cond, this, id);
+    return true;
+  } else if (dynamic_cast(node.get())) {
+    node = std::make_unique(cond, this, id);
+    return true;
+  } else if (dynamic_cast(node.get()) != nullptr) {
+    assert(false &&
+           "Unexpected traversal: arrived at a node marked 'NotToExplore'.");
+    return false;
+  }
+
+  assert(
+      dynamic_cast(node.get()) != nullptr &&
+      "Current node is not an Unexplored nor a CallIndirectNode, cannot fill "
+      "it!");
+  return false;
+}
+
+// Replace an unexplored box with a node holding `snapshot`, then set the
+// node's cost to the parent's cost (0 when this box has no parent).
+inline std::monostate NodeBox::fillSnapshotNode(Snapshot_t snapshot) {
+  if (this->isUnexplored()) {
+    // `snapshot` is our own by-value copy; hand it over instead of copying
+    // the captured state a second time.
+    node = std::make_unique(std::move(snapshot));
+  }
+  // A box without a parent must not be dereferenced; fillIfElseNode applies
+  // the same guard when it reads the parent's cost.
+  node->set_cost(parent ? parent->instr_cost() : 0.0);
+  return std::monostate();
+}
+
+inline std::monostate NodeBox::fillNotToExploreNode() {
+  if (this->isUnexplored()) {
+    node = std::make_unique();
+  } else {
+    assert(dynamic_cast(node.get()) != nullptr);
+  }
+  return std::monostate();
+}
+
+inline std::monostate NodeBox::fillFinishedNode() {
+  if (this->isUnexplored()) {
+    node = std::make_unique();
+  } else {
+
assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::monostate NodeBox::fillFailedNode() { + if (this->isUnexplored()) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::monostate NodeBox::fillUnreachableNode() { + if (this->isUnexplored()) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline bool NodeBox::isSnapshotNode() const { + assert(node != nullptr); + return dynamic_cast(node.get()) != nullptr; +} + +inline bool NodeBox::isUnexplored() const { + assert(node != nullptr); + if (dynamic_cast(node.get()) != nullptr) { + return true; + } + if (this->isSnapshotNode()) { + return true; + } + return false; +} + +inline std::vector NodeBox::collect_path_conds() { + ManagedTimer timer(TimeProfileKind::COLLECT_PATH_CONDITIONS); + auto box = this; + auto result = std::vector(); + while (box->parent) { + auto parent = box->parent; + if (auto if_else_node = dynamic_cast(parent->node.get())) { + if (if_else_node->true_branch.get() == box) { + // If the current box is the true branch, add the condition + result.push_back(if_else_node->cond); + } else if (if_else_node->false_branch.get() == box) { + // If the current box is the false branch, add the negated condition + result.push_back(if_else_node->cond.bv_negate().bool2bv()); + } else { + throw std::runtime_error("Unexpected node structure in explore tree"); + } + } else if (auto call_indirect_node = + dynamic_cast(parent->node.get())) { + // Find which branch we are in + bool found = false; + for (const auto &pair : call_indirect_node->branches) { + if (pair.second.get() == box) { + // We are in this branch + // Add the condition that leads to this branch + result.push_back( + call_indirect_node->cond.eq(Concrete(I32V(pair.first), 32))); + found = true; + break; + } + } + if (!found) { + // We must be in the 
otherwise branch + if (call_indirect_node->otherwise_branch.get() != box) { + throw std::runtime_error("Unexpected node structure in explore tree"); + } + // Add the negated conditions for all other branches + SymVal negated_conditions = Concrete(I32V(1), 32); // true + for (const auto &pair : call_indirect_node->branches) { + negated_conditions = negated_conditions.bitwise_and( + call_indirect_node->cond.neq(Concrete(I32V(pair.first), 32))); + } + result.push_back(negated_conditions); + } + } else { + // should never reach here + } + // Move to parent + box = box->parent; + } + return result; +} + +// same as collect_path_conds but return immer::vector, and cache the result +inline immer::vector NodeBox::collect_path_conds_imm() { + ManagedTimer timer(TimeProfileKind::COLLECT_PATH_CONDITIONS); + + auto box = this; + if (box->node->path_conds_cache.has_value()) { + return box->node->path_conds_cache.value(); + } + + if (!box->parent) { + // root node, and no path conditions + immer::vector empty; + box->node->path_conds_cache = empty; + return empty; + } + + auto parent_conds = box->parent->collect_path_conds_imm(); + immer::vector result = parent_conds; + if (auto if_else_node = dynamic_cast(box->parent->node.get())) { + if (if_else_node->true_branch.get() == box) { + // If the current box is the true branch, add the condition + result = result.push_back(if_else_node->cond); + } else if (if_else_node->false_branch.get() == box) { + // If the current box is the false branch, add the negated condition + result = result.push_back(if_else_node->cond.bv_negate().bool2bv()); + } else { + throw std::runtime_error("Unexpected node structure in explore tree"); + } + } else if (auto call_indirect_node = + dynamic_cast(box->parent->node.get())) { + // Find which branch we are in + bool found = false; + for (const auto &pair : call_indirect_node->branches) { + if (pair.second.get() == box) { + // We are in this branch + // Add the condition that leads to this branch + result 
= result.push_back(
+            call_indirect_node->cond.eq(Concrete(I32V(pair.first), 32)));
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      // We must be in the otherwise branch
+      if (call_indirect_node->otherwise_branch.get() != box) {
+        throw std::runtime_error("Unexpected node structure in explore tree");
+      }
+      // Add the negated conditions for all other branches
+      SymVal negated_conditions = Concrete(I32V(1), 32); // true
+      for (const auto &pair : call_indirect_node->branches) {
+        negated_conditions = negated_conditions.bitwise_and(
+            call_indirect_node->cond.neq(Concrete(I32V(pair.first), 32)));
+      }
+      result = result.push_back(negated_conditions);
+    }
+  } else {
+    // should never reach here
+  }
+  box->node->path_conds_cache = result;
+  return result;
+}
+
+// Capture the given symbolic state together with the continuation it should
+// resume from, and record a SNAPSHOT_CREATE profiling step.
+inline Snapshot_t::Snapshot_t(Cont_t cont, MCont_t mcont, SymStack_t stack,
+                              SymFrames_t frames, SymFrames_t globals,
+                              SymMemory_t memory, ImmNumMapBox num_map)
+    : stack(std::move(stack)), frames(std::move(frames)),
+      globals(std::move(globals)), memory(std::move(memory)), cont(cont),
+      mcont(mcont), num_map(num_map) {
+  Profile.step(StepProfileKind::SNAPSHOT_CREATE);
+#ifdef DEBUG
+  // Inside the body the bare name `stack` is the constructor parameter,
+  // which has already been moved from; report the member instead.
+  std::cout << "Creating snapshot of size " << this->stack.size() << std::endl;
+#endif
+}
+
+const double INSTR_COST_SCALING_FACTOR = 1E-03;
+
+// Estimated cost of reusing this snapshot: the summed total_sym_size() of
+// the stack, frames, memory and globals, scaled by
+// INSTR_COST_SCALING_FACTOR.
+inline double Snapshot_t::cost_of_snapshot() const {
+  auto stack_sym_size = stack.total_sym_size();
+  assert(stack_sym_size >= 0);
+  auto frame_sym_size = frames.total_sym_size();
+  assert(frame_sym_size >= 0);
+  auto memory_sym_size = memory.total_sym_size();
+  assert(memory_sym_size >= 0);
+  auto global_sym_size = globals.total_sym_size();
+  assert(global_sym_size >= 0);
+  // The speed ratio between symbolic expression instantiation and WebAssembly
+  // instruction execution, given by benchmark results
+  auto total_size =
+      stack_sym_size + frame_sym_size + memory_sym_size + global_sym_size;
+  return INSTR_COST_SCALING_FACTOR * total_size;
+}
+
+struct OverallResult {
+  int unexplored_count =
0; + int finished_count = 0; + int failed_count = 0; + int not_to_explore_count = 0; + int unreachable_count = 0; + + void print() { + std::cout << "Explore Tree Overall Result:" << std::endl; + std::cout << " Unexplored paths: " << unexplored_count << std::endl; + std::cout << " Finished paths: " << finished_count << std::endl; + std::cout << " Failed paths: " << failed_count << std::endl; + std::cout << " Unreachable paths: " << unreachable_count << std::endl; + std::cout << " NotToExplore paths: " << not_to_explore_count << std::endl; + } +}; + +class ExploreTree_t { +public: + explicit ExploreTree_t() + : root(std::make_unique(nullptr)), cursor(root.get()) {} + + void reset_cursor() { + GENSYM_INFO("Resetting cursor to root"); + // Reset the cursor to the root of the tree + cursor = root.get(); + } + + void clear() { + GENSYM_INFO("Clearing the explore tree"); + root = std::make_unique(nullptr); + cursor = root.get(); + true_branch_cov_map.clear(); + false_branch_cov_map.clear(); + } + + void set_cursor(NodeBox *new_cursor) { + GENSYM_INFO("Setting cursor to a new node"); + cursor = new_cursor; + assert(dynamic_cast(cursor->node.get()) != nullptr); + } + + std::monostate fillFinishedNode() { return cursor->fillFinishedNode(); } + + std::monostate fillFailedNode() { return cursor->fillFailedNode(); } + + std::monostate fillIfElseNode(SymVal cond, int id) { + if (cursor->fillIfElseNode(cond, id)) { + auto if_else_node = dynamic_cast(cursor->node.get()); + register_new_node(if_else_node->true_branch.get()); + register_new_node(if_else_node->false_branch.get()); + } + return std::monostate(); + } + + std::monostate fillCallIndirectNode(SymVal cond, int id) { + if (cursor->fillCallIndirectNode(cond, id)) { + auto indirect_node = dynamic_cast(cursor->node.get()); + register_new_node(indirect_node->otherwise_branch.get()); + } + return std::monostate(); + } + + std::monostate fillNotToExploredNode() { + return cursor->fillNotToExploreNode(); + } + + std::vector 
collect_current_path_conds() { + return cursor->collect_path_conds(); + } + + std::monostate moveCursor(bool branch, Control control) { + Profile.step(StepProfileKind::CURSOR_MOVE); + assert(cursor != nullptr); + auto if_else_node = dynamic_cast(cursor->node.get()); + assert( + if_else_node != nullptr && + "Can't move cursor when the branch node is not initialized correctly!"); + + if (branch) { + true_branch_cov_map[if_else_node->id] = true; + if (if_else_node->cond.is_concrete()) { + if_else_node->false_branch->fillUnreachableNode(); + } else { + if (REUSE_SNAPSHOT && !if_else_node->false_branch->isSnapshotNode()) { + auto snapshot = makeSnapshot(control); + if_else_node->false_branch->fillSnapshotNode(snapshot); + } else { + // Do nothing, the initial value of the branch is an unexplored node + } + } + cursor = if_else_node->true_branch.get(); + } else { + false_branch_cov_map[if_else_node->id] = true; + if (if_else_node->cond.is_concrete()) { + if_else_node->true_branch->fillUnreachableNode(); + } else { + if (REUSE_SNAPSHOT && !if_else_node->true_branch->isSnapshotNode()) { + auto snapshot = makeSnapshot(control); + if_else_node->true_branch->fillSnapshotNode(snapshot); + } else { + // Do nothing, the initial value of the branch is an unexplored node + } + } + cursor = if_else_node->false_branch.get(); + } + CostManager.reset_timer(); + return std::monostate(); + } + + std::monostate moveCursorNoControl(bool branch) { + Profile.step(StepProfileKind::CURSOR_MOVE); + assert(cursor != nullptr); + auto if_else_node = dynamic_cast(cursor->node.get()); + assert( + if_else_node != nullptr && + "Can't move cursor when the branch node is not initialized correctly!"); + if (branch) { + true_branch_cov_map[if_else_node->id] = true; + if_else_node->false_branch->fillNotToExploreNode(); + cursor = if_else_node->true_branch.get(); + } else { + assert(false && + "moveCursorNoControl should not be used for false branch"); + } + CostManager.reset_timer(); + return 
std::monostate(); + } + + std::monostate moveCursorIndirect(int branch_index) { + // Dont use snapshot reuse for untaken branches of indirect call + Profile.step(StepProfileKind::CURSOR_MOVE); + assert(cursor != nullptr); + auto branch_node = dynamic_cast(cursor->node.get()); + assert(branch_node != nullptr && + "Can't move cursor when the branch node is not initialized "); + if (branch_node->branches.find(branch_index) == + branch_node->branches.end()) { + // Create a new branch + branch_node->branches[branch_index] = std::make_unique(cursor); + register_new_node(branch_node->branches[branch_index].get()); + } + cursor = branch_node->branches[branch_index].get(); + + return std::monostate(); + } + + std::monostate print() { + std::cout << root->node->to_string() << std::endl; + return std::monostate(); + } + + std::monostate to_graphviz(std::ostream &os) { + root->node->to_graphviz(os); + return std::monostate(); + } + + std::monostate dump_graphviz(std::string filepath) { + std::filesystem::path out_path(filepath); + auto parent = out_path.parent_path(); + if (!parent.empty()) { + std::error_code ec; + std::filesystem::create_directories(parent, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + + ec.message()); + } + } + std::ofstream ofs(filepath); + if (!ofs.is_open()) { + throw std::runtime_error("Failed to open " + filepath + " for writing"); + } + to_graphviz(ofs); + return std::monostate(); + } + + OverallResult read_current_overall_result() { + OverallResult result; + std::vector stack; + stack.push_back(root.get()); + + while (!stack.empty()) { + NodeBox *node = stack.back(); + stack.pop_back(); + + if (auto if_else_node = dynamic_cast(node->node.get())) { + stack.push_back(if_else_node->true_branch.get()); + stack.push_back(if_else_node->false_branch.get()); + } else if (dynamic_cast(node->node.get())) { + result.unexplored_count += 1; + } else if (dynamic_cast(node->node.get())) { + result.finished_count += 1; + } else 
if (dynamic_cast(node->node.get())) { + result.failed_count += 1; + } else if (dynamic_cast(node->node.get())) { + result.unreachable_count += 1; + } else if (dynamic_cast(node->node.get())) { + // Snapshot node is considered unexplored + result.unexplored_count += 1; + } else if (dynamic_cast(node->node.get())) { + result.not_to_explore_count += 1; + } else if (auto call_indirect_node = + dynamic_cast(node->node.get())) { + for (const auto &pair : call_indirect_node->branches) { + stack.push_back(pair.second.get()); + } + stack.push_back(call_indirect_node->otherwise_branch.get()); + } else { + throw std::runtime_error("Unknown node type in explore tree"); + } + } + return result; + } + + std::monostate print_overall_result() {} + + NodeBox *pick_unexplored() { + // Pick an unexplored node from the tree + // For now, we just iterate through the tree and return the first unexplored + return pick_unexplored_of(root.get()); + } + std::vector true_branch_cov_map; + std::vector false_branch_cov_map; + bool all_branch_covered() const { + for (bool covered : true_branch_cov_map) { + if (!covered) + return false; + } + for (bool covered : false_branch_cov_map) { + if (!covered) + return false; + } + return true; + } + + NodeBox *get_root() const { return root.get(); } + + void register_new_node_collector(std::function func) { + new_node_collectors.push_back(func); + } + +private: + NodeBox *pick_unexplored_of(NodeBox *node) { + if (node->isUnexplored()) { + return node; + } + auto if_else_node = dynamic_cast(node->node.get()); + if (if_else_node) { + NodeBox *result = pick_unexplored_of(if_else_node->true_branch.get()); + if (result) { + return result; + } + return pick_unexplored_of(if_else_node->false_branch.get()); + } + return nullptr; // No unexplored node found + } + void register_new_node(NodeBox *node) { + for (auto &func : new_node_collectors) { + func(node); + } + } + std::unique_ptr root; + NodeBox *cursor; + std::vector> new_node_collectors; +}; + +static 
ExploreTree_t ExploreTree; + +static std::monostate reset_stacks() { + Stack.reset(); + SymStack.reset(); + Frames.reset(); + SymFrames.reset(); + Memory.reset(); + SymMemory.reset(); + initRand(); + return std::monostate{}; +} + +[[deprecated]] inline void +NodeBox::reach_here(std::function entrypoint) { + // reach the node of exploration tree with given input (symbolic environment) + if (auto snapshot = dynamic_cast(node.get())) { + assert(REUSE_SNAPSHOT); + auto snap = snapshot->get_snapshot(); + snap.resume_execution(this); + return; + } else if (parent == nullptr) { + // if it's the root node, the only way to reach here is to reset everything + // and start a new execution + assert(this == ExploreTree.get_root() && + "Only the root node can have no parent"); + auto timer = ManagedTimer(TimeProfileKind::INSTR); + ExploreTree.reset_cursor(); + reset_stacks(); + entrypoint(); + return; + } + // Reach the parent node, then from the parent node, we can reach here + // TODO: short circuit the lookup + parent->reach_here(entrypoint); + return; +} + +struct EvalRes { + Num value; + ValueKind kind; + int width; // in bits + EvalRes(Num value, int width, ValueKind kind) + : value(value), width(width), kind(kind) {} +}; + +static EvalRes eval_binary_op(EvalRes lhs_res, EvalRes rhs_res, + BinOperation operation) { + auto lhs = lhs_res.value; + auto rhs = rhs_res.value; + auto lhs_width = lhs_res.width; + auto rhs_width = rhs_res.width; + switch (operation) { + case ADD: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_add(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_add(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case SUB: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_sub(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_sub(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case MUL: + if (lhs_width == 32 
&& rhs_width == 32) { + return EvalRes(lhs.i32_mul(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_mul(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case DIV: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_div_s(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_div_s(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case LT_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_lt_s(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_lt_s(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case LEQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_le_s(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_le_s(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case GT_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_gt_s(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_gt_s(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case GEQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_ge_s(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_ge_s(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case NEQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_ne(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_ne(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case EQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_eq(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_eq(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case B_AND: 
+ if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_and(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_and(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case CONCAT: { + auto conc_value = (lhs.value << rhs_width) | (rhs.value); + auto new_width = lhs_width + rhs_width; + return EvalRes(Num(I64V(conc_value)), new_width, KindBV); + } + case B_XOR: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_xor(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_xor(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case B_OR: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_or(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_or(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case SHR_U: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_shr_u(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_shr_u(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case SHR_S: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_shr_s(rhs), 32, KindBV); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_shr_s(rhs), 64, KindBV); + } else { + assert(false && "TODO"); + } + case LTU_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_lt_u(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_lt_u(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case GTU_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return EvalRes(lhs.i32_gt_u(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_gt_u(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case GEU_BOOL: + if (lhs_width == 32 && rhs_width == 32) { + return 
EvalRes(lhs.i32_ge_u(rhs), 32, KindBool); + } else if (lhs_width == 64 && rhs_width == 64) { + return EvalRes(lhs.i64_ge_u(rhs), 32, KindBool); + } else { + assert(false && "TODO"); + } + case AND: + return EvalRes(lhs.logical_and(rhs), 32, KindBool); + case OR: + return EvalRes(lhs.logical_or(rhs), 32, KindBool); + default: + assert(false && "Operation not supported in evaluation"); + } +} + +// TODO: reduce the re-computation of the same symbolic expression, it's better +// if it can be done by the smt solver +static EvalRes eval_sym_expr(const SymVal &sym, const SymEnv_t &sym_env) { + Profile.step(StepProfileKind::SYM_EVAL); + assert(sym.symptr != nullptr && "Symbolic expression is null"); + if (auto concrete = dynamic_cast(sym.symptr.get())) { + return EvalRes(concrete->value, concrete->width(), concrete->kind); + } else if (auto extract = dynamic_cast(sym.symptr.get())) { + auto res = eval_sym_expr(extract->value, sym_env); + int high = extract->high; + int low = extract->low; + assert(high >= low && "Invalid extract range"); + int size = high - low + 1; // size in bytes + int64_t mask = (1LL << (size * 8)) - 1; + int64_t extracted_value = (res.value.toInt() >> ((low - 1) * 8)) & mask; + return EvalRes(Num(I64V(extracted_value)), size * 8, KindBV); + } else if (auto operation = dynamic_cast(sym.symptr.get())) { + // If it's a operation, we need to evaluate it + auto lhs_res = eval_sym_expr(operation->lhs, sym_env); + auto rhs_res = eval_sym_expr(operation->rhs, sym_env); + auto lhs = lhs_res.value; + auto rhs = rhs_res.value; + auto lhs_width = lhs_res.width; + auto rhs_width = rhs_res.width; + return eval_binary_op(lhs_res, rhs_res, operation->op); + } else if (auto symbol = dynamic_cast(sym.symptr.get())) { + auto sym_id = symbol->get_id(); + GENSYM_INFO("Reading symbol: " + std::to_string(sym_id)); + return EvalRes(sym_env.read(*symbol), 32, KindBV); + } + throw std::runtime_error("Not supported symbolic expression"); +} + +inline EvalRes 
eval_sym_expr_by_model(const SymVal &sym, z3::model &model); + +static void resume_conc_stack(const SymStack_t &sym_stack, Stack_t &stack, + SymEnv_t &sym_env) { + stack.resize(sym_stack.size()); + for (size_t i = 0; i < sym_stack.size(); ++i) { + auto sym = sym_stack[i]; + auto res = eval_sym_expr(sym, sym_env); + auto conc = res.value; + stack.set_from_front(i, conc); + } +} + +static void resume_conc_stack_by_model(const SymStack_t &sym_stack, + Stack_t &stack, z3::model &model) { + GENSYM_INFO("Restoring concrete stack from symbolic stack"); + stack.resize(sym_stack.size()); + for (size_t i = 0; i < sym_stack.size(); ++i) { + auto sym = sym_stack[i]; + auto res = eval_sym_expr_by_model(sym, model); + auto conc = res.value; + stack.set_from_front(i, conc); + } +} + +static void resume_conc_frames(const SymFrames_t &sym_frame, Frames_t &frames, + SymEnv_t &sym_env) { + GENSYM_INFO("Restoring concrete frames from symbolic frames"); + frames.resize(sym_frame.size()); + for (size_t i = 0; i < sym_frame.size(); ++i) { + auto sym = sym_frame[i]; + assert(sym.symptr != nullptr); + auto res = eval_sym_expr(sym, sym_env); + auto conc = res.value; + frames.set_from_front(i, conc); + } + sym_frame.restore_frame_ptr(frames); +} + +static void resume_conc_frames_by_model(const SymFrames_t &sym_frame, + Frames_t &frames, z3::model &model) { + GENSYM_INFO("Restoring concrete frames from symbolic frames"); + frames.resize(sym_frame.size()); + for (size_t i = 0; i < sym_frame.size(); ++i) { + auto sym = sym_frame[i]; + assert(sym.symptr != nullptr); + auto res = eval_sym_expr_by_model(sym, model); + auto conc = res.value; + frames.set_from_front(i, conc); + } + sym_frame.restore_frame_ptr(frames); +} + +static void resume_conc_memory(const SymMemory_t &sym_memory, Memory_t &memory, + const SymEnv_t &sym_env) { + GENSYM_INFO("Restoring concrete memory from symbolic memory"); + memory.reset(); + for (const auto &pair : sym_memory.memory) { + int32_t addr = pair.first; + SymVal sym 
= pair.second; + assert(sym.symptr != nullptr); + auto res = eval_sym_expr(sym, sym_env); + auto conc = res.value; + assert(res.width == 8 && "Memory should only store bytes"); + memory.store_byte(addr, conc.value & 0xFF); + } +} + +static void resume_conc_memory_by_model(const SymMemory_t &sym_memory, + Memory_t &memory, z3::model &model) { + GENSYM_INFO("Restoring concrete memory from symbolic memory"); + memory.reset(); + for (const auto &pair : sym_memory.memory) { + int32_t addr = pair.first; + SymVal sym = pair.second; + assert(sym.symptr != nullptr); + auto res = eval_sym_expr_by_model(sym, model); + auto conc = res.value; + assert(res.width == 8 && "Memory should only store bytes"); + memory.store_byte(addr, conc.value & 0xFF); + } +} + +static void resume_conc_states(const SymStack_t &sym_stack, + const SymFrames_t &sym_frame, + const SymFrames_t &sym_globals, + const SymMemory_t &sym_memory, Stack_t &stack, + Frames_t &frames, Frames_t &globals, + Memory_t &memory, SymEnv_t &sym_env) { + resume_conc_stack(sym_stack, stack, sym_env); + resume_conc_frames(sym_frame, frames, sym_env); + resume_conc_frames(sym_globals, globals, sym_env); + resume_conc_memory(sym_memory, memory, sym_env); +} + +static void resume_conc_states_by_model(const SymStack_t &sym_stack, + const SymFrames_t &sym_frame, + const SymFrames_t &sym_globals, + const SymMemory_t &sym_memory, + Stack_t &stack, Frames_t &frames, + Frames_t &globals, Memory_t &memory, + z3::model &model) { + resume_conc_stack_by_model(sym_stack, stack, model); + resume_conc_frames_by_model(sym_frame, frames, model); + resume_conc_frames_by_model(sym_globals, globals, model); + resume_conc_memory_by_model(sym_memory, memory, model); +} + +inline void Snapshot_t::restore_states_to_global() const { + // Restore the symbolic state from the snapshot + GENSYM_INFO("Reusing symbolic state from snapshot"); + SymStack = stack; + SymFrames = frames; + SymMemory = memory; + SymGlobals = globals; +} + +inline std::monostate 
+Snapshot_t::resume_execution_by_model(NodeBox *node, z3::model &model) const { + // Reset explore tree's cursor and restore symbolic states + ExploreTree.set_cursor(node); + restore_states_to_global(); + + { + auto timer = ManagedTimer(TimeProfileKind::RESUME_SNAPSHOT); + // Restore the concrete states from the symbolic states + resume_conc_states_by_model(stack, frames, globals, memory, Stack, Frames, + Globals, Memory, model); + } + // Resume execution from the continuation + auto timer = ManagedTimer(TimeProfileKind::INSTR); + CostManager.reset_timer(); + CURRENT_MCONT = mcont; + return cont(std::monostate{}); +} + +[[deprecated]] inline std::monostate +Snapshot_t::resume_execution(NodeBox *node) const { + // Reset explore tree's cursor and restore symbolic states + ExploreTree.set_cursor(node); + restore_states_to_global(); + { + auto timer = ManagedTimer(TimeProfileKind::RESUME_SNAPSHOT); + // Restore the concrete states from the symbolic states + resume_conc_states(stack, frames, globals, memory, Stack, Frames, Globals, + Memory, SymEnv); + } + + // Resume execution from the continuation + auto timer = ManagedTimer(TimeProfileKind::INSTR); + CURRENT_MCONT = mcont; + return cont(std::monostate{}); +} + +#endif // WASM_SYMBOLIC_RT_HPP diff --git a/genwasym_runtime/include/wasm/symbolic_decl.hpp b/genwasym_runtime/include/wasm/symbolic_decl.hpp new file mode 100644 index 00000000..0e2b5bd2 --- /dev/null +++ b/genwasym_runtime/include/wasm/symbolic_decl.hpp @@ -0,0 +1,327 @@ +#ifndef WASM_SYMVAL_REPR_HPP +#define WASM_SYMVAL_REPR_HPP + +#include "symval_decl.hpp" +#include +#include + +enum BinOperation { + ADD, // Addition + SUB, // Subtraction + MUL, // Multiplication + DIV, // Division + DIV_U, // Unsigned division + AND, // Logical AND + OR, // Logical OR + EQ_BOOL, // Equal (return a boolean) TODO: remove bv version of comparison ops + NEQ_BOOL, // Not equal (return a boolean) + LT_BOOL, // Less than (return a boolean) + LTU_BOOL, // Unsigned less than 
(return a boolean) + LEQ_BOOL, // Less than or equal (return a boolean) + LEU_BOOL, // Unsigned less than or equal (return a boolean) + GT_BOOL, // Greater than (return a boolean) + GTU_BOOL, // Unsigned greater than (return a boolean) + GEQ_BOOL, // Greater than or equal (return a boolean) + GEU_BOOL, // Unsigned greater than or equal (return a boolean) + SHL, // Shift left + SHR_U, // Shift right unsigned + SHR_S, // Shift right signed + REM_U, // Unsigned remainder + B_AND, // Bitwise AND + B_XOR, // Bitwise XOR + B_OR, // Bitwise OR + CONCAT, // Byte-level concatenation +}; + +enum UnaryOperation { + NOT, // bool not + BOOL2BV, // bool to bitvector, + EXTEND, // bitvector extension, extend i32 to i64 +}; + +enum ValueKind { KindBV, KindBool, KindFP }; + +class Symbolic { +public: + Symbolic() {} + virtual ~Symbolic() = default; // Make Symbolic polymorphic + virtual int size() = 0; + virtual ValueKind value_kind() = 0; + virtual int width() = 0; + virtual z3::expr z3_expr(); + +private: + z3::expr build_z3_expr_aux(); + std::optional _z3_expr; +}; + +class Symbol : public Symbolic { +public: + // TODO: add type information to determine the size of bitvector + // for now we just assume that only i32 will be used + Symbol(int id, int width, ValueKind kind) + : id(id), _width(width), _kind(kind) {} + int get_id() const { return id; } + + int size() override { return 1; } + + ValueKind value_kind() override { return _kind; } + int width() override { return _width; } + +private: + int id; + int _width; + ValueKind _kind; +}; + +class Witness : public Symbolic { +public: + int size() override { return 1; } + + ValueKind value_kind() override { return KindBV; } + + int width() override { return 32; } +}; + +class SymConcrete : public Symbolic { +public: + Num value; + ValueKind kind; + SymConcrete(Num num, ValueKind kind, int width) + : value(num), kind(kind), _width(width) {} + + int size() override { return 1; } + + ValueKind value_kind() override { return kind; } + 
int width() override { return _width; } + +private: + int _width; +}; + +inline int count_dag_size(Symbolic &val); + +// Extract is different from other operations, it only has one symbolic operand, +// the other two operands are constants +// Extract from value, both high and low are inclusive byte indexes +struct SymExtract : public Symbolic { + SymVal value; + int high; + int low; + + SymExtract(SymVal value, int high, int low) + : value(value), high(high), low(low) {} + + int size() override { + if (_cached_dag_size.has_value()) { + return _cached_dag_size.value(); + } + _cached_dag_size = 1 + value->size(); + return _cached_dag_size.value(); + } + + ValueKind value_kind() override { return KindBV; } + + int width() override { return (high - low + 1) * 8; } + +private: + friend std::tuple + count_dag_size_aux(Symbolic &val, std::set &visited); + + std::optional _cached_dag_size; +}; + +struct SymBinary : public Symbolic { + BinOperation op; + SymVal lhs; + SymVal rhs; + + SymBinary(BinOperation op, SymVal lhs, SymVal rhs) + : op(op), lhs(lhs), rhs(rhs) { + auto lhs_kind = lhs->value_kind(); + auto rhs_kind = rhs->value_kind(); + auto lhs_width = lhs->width(); + auto rhs_width = rhs->width(); + + switch (op) { + case ADD: + case SUB: + case MUL: + case DIV: + case DIV_U: + case SHL: + case SHR_U: + case SHR_S: + case REM_U: + case B_AND: + case B_XOR: + case B_OR: + assert(lhs_kind == KindBV && rhs_kind == KindBV); + assert(lhs_width == rhs_width); + _kind = KindBV; + _width = lhs_width; + break; + case CONCAT: + assert(lhs_kind == KindBV && rhs_kind == KindBV); + _kind = KindBV; + _width = lhs_width + rhs_width; + break; + case EQ_BOOL: + case NEQ_BOOL: + case LT_BOOL: + case LTU_BOOL: + case LEQ_BOOL: + case LEU_BOOL: + case GT_BOOL: + case GTU_BOOL: + case GEQ_BOOL: + case GEU_BOOL: + assert(lhs_kind == rhs_kind); + if (lhs_kind == KindBV) { + assert(lhs_width == rhs_width); + } + _kind = KindBool; + _width = 1; + break; + case AND: + case OR: + 
assert(lhs_kind == KindBool && rhs_kind == KindBool); + assert(lhs_width == 1 && rhs_width == 1); + _kind = KindBool; + _width = 1; + break; + default: + assert(false && "Unhandled binary operation"); + } + } + + int size() override { + if (_cached_dag_size.has_value()) { + return _cached_dag_size.value(); + } + + auto size = count_dag_size(*this); + _cached_dag_size = size; + return size; + } + + int width() override { return _width; } + + ValueKind value_kind() override { return _kind; } + +private: + friend std::tuple + count_dag_size_aux(Symbolic &val, std::set &visited); + std::optional _cached_dag_size; + ValueKind _kind; + int _width; +}; + +struct SymUnary : public Symbolic { + UnaryOperation op; + SymVal value; + + SymUnary(UnaryOperation op, SymVal value) : op(op), value(value) { + switch (op) { + case BOOL2BV: + assert(value->value_kind() == KindBool); + _width = 32; // Only 32 bit bit vector can be converted to boolean and + // vice versa. + break; + case NOT: + _width = 1; + break; + default: + assert(false && "Unknown unary operation"); + } + } + + int width() override { return _width; } + + int size() override { + if (_cached_dag_size.has_value()) { + return _cached_dag_size.value(); + } + _cached_dag_size = 1 + value->size(); + return _cached_dag_size.value(); + } + + ValueKind value_kind() override { + switch (op) { + case NOT: { + return ValueKind::KindBool; + } + case BOOL2BV: { + return ValueKind::KindBV; + } + default: { + assert(false && "Unknown unary operation"); + } + } + } + +private: + friend std::tuple + count_dag_size_aux(Symbolic &val, std::set &visited); + + int _width; + std::optional _cached_dag_size; +}; + +inline std::tuple count_dag_size_aux(Symbolic &val, + std::set &visited) { + if (visited.find(&val) != visited.end()) { + return {0, true}; + } + visited.insert(&val); + + if (auto binary = dynamic_cast(&val)) { + int size = 1; + auto [lhs_size, lhs_sharing] = + count_dag_size_aux(*binary->lhs.symptr, visited); + auto [rhs_size, 
rhs_sharing] = + count_dag_size_aux(*binary->rhs.symptr, visited); + size += lhs_size + rhs_size; + if (!lhs_sharing && !rhs_sharing) { + // if there is no sharing in two operands, this temporary size is valid + // and reusable + binary->_cached_dag_size = size; + } + return {size, lhs_sharing || rhs_sharing}; + } else if (auto unary = dynamic_cast(&val)) { + int size = 1; + auto [value_size, value_sharing] = + count_dag_size_aux(*unary->value.symptr, visited); + size += value_size; + if (!value_sharing) { + unary->_cached_dag_size = size; + } + return {size, value_sharing}; + + } else if (auto extract = dynamic_cast(&val)) { + int size = 1; + auto [value_size, value_sharing] = + count_dag_size_aux(*extract->value.symptr, visited); + size += value_size; + if (!value_sharing) { + extract->_cached_dag_size = size; + } + return {size, value_sharing}; + } else if (auto symbol = dynamic_cast(&val)) { + return {1, false}; + } else if (auto concrete = dynamic_cast(&val)) { + return {1, false}; + } else if (auto witness = dynamic_cast(&val)) { + assert(false && "Witness should not appear during instruction execution"); + } else { + assert(false && "Unknown symbolic type in dag size counting"); + } +} + +inline int count_dag_size(Symbolic &val) { + std::set visited; + auto [size, _] = count_dag_size_aux(val, visited); + return size; +} + +#endif // WASM_SYMVAL_REPR_HPP diff --git a/genwasym_runtime/include/wasm/symbolic_impl.hpp b/genwasym_runtime/include/wasm/symbolic_impl.hpp new file mode 100644 index 00000000..8c6fb067 --- /dev/null +++ b/genwasym_runtime/include/wasm/symbolic_impl.hpp @@ -0,0 +1,182 @@ +#ifndef WASM_SYMBOLIC_IMPL_HPP +#define WASM_SYMBOLIC_IMPL_HPP + +#include "symbolic_decl.hpp" +#include "wasm/symval_decl.hpp" +#include "wasm/z3_env.hpp" + +inline z3::expr Symbolic::build_z3_expr_aux() { + if (auto sym = dynamic_cast(this)) { + switch (sym->value_kind()) { + + case KindBV: { + return global_z3_ctx().bv_const( + ("s_int" + 
std::to_string(sym->get_id())).c_str(), width()); + } + case KindBool: { + assert(false && "Symbolic boolean variables are not supported yet"); + } + case KindFP: + if (sym->width() == 32) { + return global_z3_ctx().fpa_const<32>( + ("s_f32" + std::to_string(sym->get_id())).c_str()); + } else if (sym->width() == 64) { + return global_z3_ctx().fpa_const<64>( + ("s_f64" + std::to_string(sym->get_id())).c_str()); + } else { + throw std::runtime_error("Unsupported floating-point width: " + + std::to_string(sym->width())); + } + } + } else if (auto witness = dynamic_cast(this)) { + return global_z3_ctx().bv_const("witness", 32); + } else if (auto concrete = dynamic_cast(this)) { + switch (concrete->kind) { + case KindBool: { + return global_z3_ctx().bool_val(concrete->value.toInt() != 0); + } + case KindBV: { + return global_z3_ctx().bv_val(concrete->value.value, width()); + } + case KindFP: { + if (width() == 32) { + return global_z3_ctx().fpa_val(concrete->value.toF32()); + } else if (width() == 64) { + return global_z3_ctx().fpa_val(concrete->value.toF64()); + } else { + throw std::runtime_error("Unsupported floating-point width: " + + std::to_string(width())); + } + } + } + } else if (auto binary = dynamic_cast(this)) { + auto bit_width = width(); + + z3::expr left = binary->lhs->z3_expr(); + z3::expr right = binary->rhs->z3_expr(); + switch (binary->op) { + case EQ_BOOL: { + return left == right; + } + case NEQ_BOOL: { + return left != right; + } + case AND: { + return left && right; + } + case OR: { + return left || right; + } + case LT_BOOL: { + return left < right; + } + case LTU_BOOL: { + return z3::ult(left, right); + } + case LEQ_BOOL: { + return left <= right; + } + case LEU_BOOL: { + return z3::ule(left, right); + } + case GT_BOOL: { + return left > right; + } + case GTU_BOOL: { + return z3::ugt(left, right); + } + case GEU_BOOL: { + return z3::uge(left, right); + } + case SHL: { + if (bit_width == 32) { + z3::expr shift_mask = global_z3_ctx().bv_val(0x1F, 
bit_width); + return z3::shl(left, right & shift_mask); + } else if (bit_width == 64) { + z3::expr shift_mask = global_z3_ctx().bv_val(0x3F, bit_width); + return z3::shl(left, right & shift_mask); + } else { + throw std::runtime_error("Unsupported bit width for SHL: " + + std::to_string(bit_width)); + } + } + case SHR_U: { + return z3::lshr(left, right); + } + case SHR_S: { + return z3::ashr(left, right); + } + case REM_U: { + return z3::urem(left, right); + } + case GEQ_BOOL: { + return left >= right; + } + case ADD: { + return left + right; + } + case SUB: { + return left - right; + } + case MUL: { + return left * right; + } + case DIV: { + return left / right; + } + case DIV_U: { + return z3::udiv(left, right); + } + case B_AND: { + return left & right; + } + case B_XOR: { + return left ^ right; + } + case B_OR: { + return left | right; + } + case CONCAT: { + return z3::concat(left, right); + } + default: + throw std::runtime_error("Operation not supported: " + + std::to_string(binary->op)); + } + } else if (auto unary = dynamic_cast(this)) { + auto bit_width = 32; + z3::expr zero_bv = global_z3_ctx().bv_val(0, bit_width); + z3::expr one_bv = global_z3_ctx().bv_val(1, bit_width); + switch (unary->op) { + case NOT: { + return !unary->value->z3_expr(); + } + case BOOL2BV: { + z3::expr bool_expr = unary->value->z3_expr(); + return z3::ite(bool_expr, one_bv, zero_bv); + } + default: + throw std::runtime_error("Unary operation not supported: " + + std::to_string(unary->op)); + } + } else if (auto extract = dynamic_cast(this)) { + assert(extract); + int high = extract->high * 8 - 1; + int low = extract->low * 8 - 8; + auto s = extract->value->z3_expr(); + auto res = s.extract(high, low); + return res; + } + throw std::runtime_error("Unsupported symbolic value type"); +} + +inline z3::expr Symbolic::z3_expr() { + if (_z3_expr.has_value()) { + return *_z3_expr; + } + auto e = build_z3_expr_aux(); + _z3_expr = e; + return e; +} + +#endif // WASM_SYMBOLIC_IMPL_HPP diff 
--git a/genwasym_runtime/include/wasm/symval_decl.hpp b/genwasym_runtime/include/wasm/symval_decl.hpp new file mode 100644 index 00000000..8f9ac1a3 --- /dev/null +++ b/genwasym_runtime/include/wasm/symval_decl.hpp @@ -0,0 +1,95 @@ +#ifndef WASM_SYMVAL_HPP +#define WASM_SYMVAL_HPP +#include "concrete_num.hpp" +#include +#include +#include + +class Symbolic; + +struct SymVal { + std::shared_ptr symptr; + + SymVal() = delete; + SymVal(std::shared_ptr symptr) : symptr(symptr) {} + + // Create a new i32 symbol value + SymVal makeI32Symbol() const; + // Create a new i64 symbol value + SymVal makeI64Symbol() const; + // Create a new f32 symbol value + SymVal makeF32Symbol() const; + // Create a new f64 symbol value + SymVal makeF64Symbol() const; + + // bitvector arithmetic operations + SymVal is_zero() const; + SymVal add(const SymVal &other) const; + SymVal minus(const SymVal &other) const; + SymVal mul(const SymVal &other) const; + SymVal div(const SymVal &other) const; + SymVal div_u(const SymVal &other) const; + SymVal eq_bool(const SymVal &other) const; + SymVal neq_bool(const SymVal &other) const; + SymVal land(const SymVal &other) const; + SymVal lor(const SymVal &other) const; + SymVal eq(const SymVal &other) const; + SymVal neq(const SymVal &other) const; + SymVal lt(const SymVal &other) const; + SymVal ltu(const SymVal &other) const; + SymVal le(const SymVal &other) const; + SymVal leu(const SymVal &other) const; + SymVal gt(const SymVal &other) const; + SymVal gtu(const SymVal &other) const; + SymVal ge(const SymVal &other) const; + SymVal geu(const SymVal &other) const; + SymVal shl(const SymVal &other) const; + SymVal shr_u(const SymVal &other) const; + SymVal shr_s(const SymVal &other) const; + SymVal bv_negate() const; + SymVal bool_not() const; + SymVal bitwise_and(const SymVal &other) const; + SymVal bitwise_xor(const SymVal &other) const; + SymVal bitwise_or(const SymVal &other) const; + SymVal concat(const SymVal &other) const; + SymVal extract(int 
high, int low) const; + SymVal bv2bool() const; + SymVal bool2bv() const; + SymVal rem_u(const SymVal &other) const; + SymVal extend_to_i64() const; // only for i32 symbolic values, extend to i64 by sign extension + // TODO: add bitwise operations, and use the underlying bitvector theory + + bool is_concrete() const; + + static SymVal get_witness_symbol(); + + Symbolic *operator->() const { return symptr.get(); } + bool operator==(const SymVal &other) const { return symptr == other.symptr; } +}; + +struct SymValHash { + size_t operator()(const SymVal &key) const { + return std::hash{}(key.symptr.get()); + } +}; + +using SymValSet = std::unordered_set; + +template +using SymValMap = std::unordered_map; + +template inline bool allConcrete(const Args &...args) { + static_assert((std::is_same_v && ...), + "all_concrete only accepts SymVal arguments"); + return (... && args.is_concrete()); +} + +inline SymVal Concrete(Num num, int width); + +[[noreturn]] inline SymVal debug_unreachable(const char* msg) { + std::cerr << "unreachable: " << msg << '\n'; + assert(false && "unreachable reached"); + std::abort(); +} + +#endif // WASM_SYMVAL_HPP diff --git a/genwasym_runtime/include/wasm/symval_factory.hpp b/genwasym_runtime/include/wasm/symval_factory.hpp new file mode 100644 index 00000000..14875764 --- /dev/null +++ b/genwasym_runtime/include/wasm/symval_factory.hpp @@ -0,0 +1,654 @@ +#ifndef WASM_SYMVAL_FACTORY_HPP +#define WASM_SYMVAL_FACTORY_HPP + +#include "heap_mem_bookkeeper.hpp" +#include "symbolic_decl.hpp" +#include "symval_decl.hpp" + +namespace SVFactory { + +SymVal make_concrete_bv(Num num, int width); +SymVal make_concrete_bool(bool b); +SymVal make_int_symbolic(int index, int width); +SymVal make_smallbv(int width, int64_t value); +SymVal make_binary(BinOperation op, const SymVal &lhs, const SymVal &rhs); +SymVal make_unary(UnaryOperation op, const SymVal &value); +SymVal make_extract(const SymVal &value, int high, int low); + +// Core allocator and common 
constants. +static MemBookKeeper SymBookKeeper; + +static SymVal I32ZERO = + SymVal(SymBookKeeper.allocate(I32V(0), KindBV, 32)); + +static SymVal I64ZERO = + SymVal(SymBookKeeper.allocate(I64V(0), KindBV, 64)); + +static SymVal TRUE = + SymVal(SymBookKeeper.allocate(I32V(1), KindBool, 32)); + +static SymVal FALSE = + SymVal(SymBookKeeper.allocate(I32V(0), KindBool, 32)); + +static SymVal ZeroByte = + SymVal(SymBookKeeper.allocate(I64V(0), KindBV, 8)); + +// Key and hash types. +struct SmallBVKey { + int width; + int64_t value; + SmallBVKey(int width, int64_t value) : width(width), value(value) {} + + bool operator==(const SmallBVKey &other) const { + return width == other.width && value == other.value; + } +}; + +struct SmallBVKeyHash { + size_t operator()(const SmallBVKey &key) const { + size_t h1 = std::hash{}(key.width); + size_t h2 = std::hash{}(key.value); + return h1 ^ (h2 << 1); + } +}; + +struct ExtractKey { + SymVal value; + int high; + int low; + ExtractKey(const SymVal &value, int high, int low) + : value(value), high(high), low(low) {} + + bool operator==(const ExtractKey &other) const { + return value.symptr == other.value.symptr && high == other.high && + low == other.low; + } +}; + +struct ExtractKeyHash { + size_t operator()(const ExtractKey &key) const { + size_t h1 = std::hash{}(key.value.symptr.get()); + size_t h2 = std::hash{}(key.high); + size_t h3 = std::hash{}(key.low); + return h1 ^ (h2 << 1) ^ (h3 << 2); + } +}; + +struct BinOpKey { + BinOperation op; + SymVal lhs; + SymVal rhs; + BinOpKey(BinOperation op, const SymVal &lhs, const SymVal &rhs) + : op(op), lhs(lhs), rhs(rhs) {} + + bool operator==(const BinOpKey &other) const { + return op == other.op && lhs.symptr == other.lhs.symptr && + rhs.symptr == other.rhs.symptr; + } +}; + +struct BinOpKeyHash { + size_t operator()(const BinOpKey &key) const { + size_t h1 = std::hash{}(static_cast(key.op)); + size_t h2 = std::hash{}(key.lhs.symptr.get()); + size_t h3 = 
std::hash{}(key.rhs.symptr.get()); + return h1 ^ (h2 << 1) ^ (h3 << 2); + } +}; + +struct UnaryOpKey { + UnaryOperation op; + SymVal value; + UnaryOpKey(UnaryOperation op, const SymVal &value) : op(op), value(value) {} + + bool operator==(const UnaryOpKey &other) const { + return op == other.op && value.symptr == other.value.symptr; + } +}; + +struct UnaryOpKeyHash { + size_t operator()(const UnaryOpKey &key) const { + size_t h1 = std::hash{}(static_cast(key.op)); + size_t h2 = std::hash{}(key.value.symptr.get()); + return h1 ^ (h2 << 1); + } +}; + +// Caches. +static std::unordered_map SymbolStore; +static std::unordered_map FPStore; +static std::unordered_map SmallBVStore; +static std::unordered_map + ExtractOperationStore; +static std::unordered_map BinaryOperationStore; +static std::unordered_map + UnaryOperationStore; + +// Factory implementations. +inline SymVal make_concrete_bv(Num num, int width) { + auto key = SmallBVKey(width, num.toInt64()); + auto it = SmallBVStore.find(key); + if (it != SmallBVStore.end()) { + return it->second; + } + + auto new_val = + SymVal(SymBookKeeper.allocate(num, KindBV, width)); + SmallBVStore.insert({key, new_val}); + return new_val; +} + +inline SymVal make_concrete_fp(Num num, int width) { + auto it = FPStore.find(num.toInt64()); + if (it != FPStore.end()) { + return it->second; + } + + auto new_val = + SymVal(SymBookKeeper.allocate(num, KindFP, width)); + FPStore.insert({num.toInt64(), new_val}); + return new_val; +} + +inline SymVal make_concrete_bool(bool b) { + if (b) { + return TRUE; + } else { + return FALSE; + } +} + +inline SymVal make_int_symbolic(int index, int width) { + auto it = SymbolStore.find(index); + if (it != SymbolStore.end()) { + return it->second; + } + SymVal new_symbol = + SymVal(SymBookKeeper.allocate(index, width, KindBV)); + SymbolStore.insert({index, new_symbol}); + return new_symbol; +} + +inline SymVal make_fp_symbolic(int index, int width) { + auto it = SymbolStore.find(index); + if (it != 
SymbolStore.end()) { + return it->second; + } + SymVal new_symbol = + SymVal(SymBookKeeper.allocate(index, width, KindFP)); + SymbolStore.insert({index, new_symbol}); + return new_symbol; +} + +inline SymVal make_smallbv(int width, int64_t value) { + if (width == 32) { + return make_concrete_bv(I32V(value), width); + } + if (width == 64) { + return make_concrete_bv(I64V(value), width); + } + auto key = SmallBVKey(width, value); + auto it = SmallBVStore.find(key); + if (it != SmallBVStore.end()) { + return it->second; + } + auto new_val = + SymVal(SymBookKeeper.allocate(I64V(value), KindBV, width)); + SmallBVStore.insert({key, new_val}); + return new_val; +} + +inline SymVal make_extract(const SymVal &value, int high, int low) { + assert(value.symptr != nullptr && "Symbolic expression is null in extract"); + assert(high >= low && "Invalid extract range"); + int new_width = (high - low + 1) * 8; + int shift_bits = (low - 1) * 8; + + if (auto concrete = std::dynamic_pointer_cast(value.symptr)) { + // extract from concrete bitvector value + int64_t val = concrete->value.value; + int64_t mask = (1LL << ((high - low + 1) * 8)) - 1; + int64_t new_value = (val >> shift_bits) & mask; + return SVFactory::make_smallbv(new_width, new_value); + } + + // If the value is already an extract, we can merge the two extracts into one + if (auto extract = std::dynamic_pointer_cast(value.symptr)) { + if (extract->low == low && extract->high == high) { + // extracting the same range, return directly + return value; + } + } + + // Otherwise, create a new extract symbolic value + ExtractKey key(value, high, low); + auto it = ExtractOperationStore.find(key); + if (it != ExtractOperationStore.end()) { + return it->second; + } + auto result = SymVal(SymBookKeeper.allocate(value, high, low)); + ExtractOperationStore.insert({key, result}); + return result; +} + +inline SymVal make_binary(BinOperation op, const SymVal &lhs, + const SymVal &rhs) { + assert(lhs.symptr != nullptr && rhs.symptr != 
nullptr); + + BinOpKey key(op, lhs, rhs); + auto it = BinaryOperationStore.find(key); + if (it != BinaryOperationStore.end()) { + return it->second; + } + + if (auto lhs_concrete = dynamic_cast(lhs.symptr.get())) { + if (auto rhs_concrete = dynamic_cast(rhs.symptr.get())) { + auto lhs_value = lhs_concrete->value; + auto rhs_value = rhs_concrete->value; + auto lhs_width = lhs_concrete->width(); + auto rhs_width = rhs_concrete->width(); + + auto make_eval_bv = [&](Num num, int width) { + auto result = SVFactory::make_concrete_bv(num, width); + BinaryOperationStore.insert({key, result}); + return result; + }; + auto make_eval_bool = [&](Num num) { + auto result = SVFactory::make_concrete_bool(num.value); + BinaryOperationStore.insert({key, result}); + return result; + }; + + switch (op) { + case ADD: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_add(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_add(rhs_value), 64); + break; + case SUB: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_sub(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_sub(rhs_value), 64); + break; + case MUL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_mul(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_mul(rhs_value), 64); + break; + case DIV: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_div_s(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_div_s(rhs_value), 64); + break; + case LT_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_lt_s(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_lt_s(rhs_value)); + break; + case LEQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return 
make_eval_bool(lhs_value.i32_le_s(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_le_s(rhs_value)); + break; + case GT_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_gt_s(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_gt_s(rhs_value)); + break; + case GEQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_ge_s(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_ge_s(rhs_value)); + break; + case NEQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_ne(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_ne(rhs_value)); + break; + case EQ_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_eq(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_eq(rhs_value)); + break; + case B_AND: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_and(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_and(rhs_value), 64); + break; + case CONCAT: { + auto conc_value = (lhs_value.value << rhs_width) | rhs_value.value; + auto new_width = lhs_width + rhs_width; + return make_eval_bv(Num(I64V(conc_value)), new_width); + } + case B_XOR: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_xor(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_xor(rhs_value), 64); + break; + case B_OR: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_or(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_or(rhs_value), 64); + break; + case SHR_U: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_shr_u(rhs_value), 32); + 
if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_shr_u(rhs_value), 64); + break; + case SHR_S: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_shr_s(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_shr_s(rhs_value), 64); + break; + case SHL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bv(lhs_value.i32_shl(rhs_value), 32); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bv(lhs_value.i64_shl(rhs_value), 64); + break; + case LTU_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_lt_u(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_lt_u(rhs_value)); + break; + case LEU_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_le_u(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_le_u(rhs_value)); + break; + case GTU_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_gt_u(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_gt_u(rhs_value)); + break; + case GEU_BOOL: + if (lhs_width == 32 && rhs_width == 32) + return make_eval_bool(lhs_value.i32_ge_u(rhs_value)); + if (lhs_width == 64 && rhs_width == 64) + return make_eval_bool(lhs_value.i64_ge_u(rhs_value)); + break; + case AND: + return make_eval_bool(lhs_value.logical_and(rhs_value)); + case OR: + return make_eval_bool(lhs_value.logical_or(rhs_value)); + default: + break; + } + assert(false && "Operation not supported in evaluation"); + } + } + + if (op == EQ_BOOL) { + if (auto lhs_unary = dynamic_cast(lhs.symptr.get())) { + if (auto rhs_concrete = dynamic_cast(rhs.symptr.get())) { + if (lhs_unary->op == BOOL2BV) { + auto rhs_value = rhs_concrete->value; + if (rhs_value.value == 0) { + auto result = lhs_unary->value.bool_not(); + BinaryOperationStore.insert({key, 
result}); + return result; + } + } + } + } + + if (auto rhs_unary = dynamic_cast(rhs.symptr.get())) { + if (auto lhs_concrete = dynamic_cast(lhs.symptr.get())) { + if (rhs_unary->op == BOOL2BV) { + auto lhs_value = lhs_concrete->value; + if (lhs_value.value == 0) { + auto result = rhs_unary->value.bool_not(); + BinaryOperationStore.insert({key, result}); + return result; + } + } + } + } + } + + if (op == NEQ_BOOL) { + if (auto lhs_unary = dynamic_cast(lhs.symptr.get())) { + if (auto rhs_concrete = dynamic_cast(rhs.symptr.get())) { + if (rhs_concrete->kind == KindBV && rhs_concrete->value.value == 0) { + if (lhs_unary->op == BOOL2BV) { + auto result = lhs_unary->value; + BinaryOperationStore.insert({key, result}); + return result; + } + } + } + } + if (auto rhs_unary = dynamic_cast(rhs.symptr.get())) { + if (auto lhs_concrete = dynamic_cast(lhs.symptr.get())) { + if (lhs_concrete->kind == KindBV && lhs_concrete->value.value == 0) { + if (rhs_unary->op == BOOL2BV) { + auto result = rhs_unary->value; + BinaryOperationStore.insert({key, result}); + return result; + } + } + } + } + } + + if (op == EQ_BOOL && lhs == rhs) { + auto result = SVFactory::make_concrete_bool(true); + BinaryOperationStore.insert({key, result}); + return result; + } + + if (op == NEQ_BOOL && lhs == rhs) { + auto result = SVFactory::make_concrete_bool(false); + BinaryOperationStore.insert({key, result}); + return result; + } + + if ((op == GT_BOOL || op == LT_BOOL || op == NEQ_BOOL) && lhs == rhs) { // only strict comparisons of a value with itself fold to false; a bare enumerator here matched every operation + auto result = SVFactory::make_concrete_bool(false); + BinaryOperationStore.insert({key, result}); + return result; + } + + if (op == AND) { + if (auto rhs_concrete = dynamic_cast(rhs.symptr.get())) { + if (rhs_concrete->kind == KindBool && rhs_concrete->value.value == 0) { + auto result = SVFactory::make_concrete_bool(false); + BinaryOperationStore.insert({key, result}); + return result; + } + } + if (auto lhs_concrete = dynamic_cast(lhs.symptr.get())) { + if (lhs_concrete->kind == KindBool && 
lhs_concrete->value.value == 0) { + auto result = SVFactory::make_concrete_bool(false); + BinaryOperationStore.insert({key, result}); + return result; + } + } + if (auto rhs_concrete = dynamic_cast(rhs.symptr.get())) { + if (rhs_concrete->kind == KindBool && rhs_concrete->value.value != 0) { + BinaryOperationStore.insert({key, lhs}); + return lhs; + } + } + if (auto lhs_concrete = dynamic_cast(lhs.symptr.get())) { + if (lhs_concrete->kind == KindBool && lhs_concrete->value.value != 0) { + BinaryOperationStore.insert({key, rhs}); + return rhs; + } + } + } + + if (op == B_AND) { + if (auto lhs_unary = dynamic_cast(lhs.symptr.get())) { + if (auto rhs_unary = dynamic_cast(rhs.symptr.get())) { + if (lhs_unary->op == BOOL2BV && rhs_unary->op == BOOL2BV) { + auto result = lhs_unary->value.land(rhs_unary->value).bool2bv(); + BinaryOperationStore.insert({key, result}); + return result; + } + } + } + + if (auto rhs_concrete = dynamic_cast(rhs.symptr.get())) { + if (rhs_concrete->kind == KindBV && rhs_concrete->value.value == 1) { + if (auto lhs_unary = dynamic_cast(lhs.symptr.get())) { + if (lhs_unary->op == BOOL2BV) { + BinaryOperationStore.insert({key, lhs}); + return lhs; + } + } + } + } + + if (auto lhs_concrete = dynamic_cast(lhs.symptr.get())) { + if (lhs_concrete->kind == KindBV && lhs_concrete->value.value == 1) { + if (auto rhs_unary = dynamic_cast(rhs.symptr.get())) { + if (rhs_unary->op == BOOL2BV) { + BinaryOperationStore.insert({key, rhs}); + return rhs; + } + } + } + } + } + + auto result = + SymVal(SVFactory::SymBookKeeper.allocate(op, lhs, rhs)); + BinaryOperationStore.insert({key, result}); + return result; +} + +inline SymVal make_unary(UnaryOperation op, const SymVal &value) { + assert(value.symptr != nullptr); + + UnaryOpKey key(op, value); + auto it = UnaryOperationStore.find(key); + if (it != UnaryOperationStore.end()) { + return it->second; + } + + if (op == BOOL2BV) { + if (auto concrete = dynamic_cast(value.symptr.get())) { + auto value_conc = 
concrete->value; + if (concrete->kind == KindBool) { + if (value_conc.value != 0) { + auto result = SVFactory::make_concrete_bv(Num(I32V(1)), 32); + UnaryOperationStore.insert({key, result}); + return result; + } else { + auto result = SVFactory::make_concrete_bv(Num(I32V(0)), 32); + UnaryOperationStore.insert({key, result}); + return result; + } + } + } + } + + if (op == NOT) { + if (auto concrete = dynamic_cast(value.symptr.get())) { + if (concrete->kind == KindBool) { + auto result = SVFactory::make_concrete_bool(concrete->value.value == 0); + UnaryOperationStore.insert({key, result}); + return result; + } + } + + if (auto inner_unary = dynamic_cast(value.symptr.get())) { + if (inner_unary->op == NOT) { + auto result = inner_unary->value; + UnaryOperationStore.insert({key, result}); + return result; + } + } + + if (auto inner_binary = dynamic_cast(value.symptr.get())) { + BinOperation negated_op; + switch (inner_binary->op) { + case EQ_BOOL: + negated_op = NEQ_BOOL; + break; + case NEQ_BOOL: + negated_op = EQ_BOOL; + break; + case LT_BOOL: + negated_op = GEQ_BOOL; + break; + case GT_BOOL: + negated_op = LEQ_BOOL; + break; + case LEQ_BOOL: + negated_op = GT_BOOL; + break; + case GEQ_BOOL: + negated_op = LT_BOOL; + break; + default: + negated_op = inner_binary->op; + break; + } + if (negated_op != inner_binary->op) { + auto result = SVFactory::make_binary(negated_op, inner_binary->lhs, + inner_binary->rhs); + UnaryOperationStore.insert({key, result}); + return result; + } + } + } + + auto result = SymVal(SVFactory::SymBookKeeper.allocate(op, value)); + UnaryOperationStore.insert({key, result}); + return result; +} + +inline SymVal make_concat(const SymVal &lhs, const SymVal &rhs) { + if (auto lhs_concrete = std::dynamic_pointer_cast(lhs.symptr)) { + if (auto rhs_concrete = + std::dynamic_pointer_cast(rhs.symptr)) { + if (lhs_concrete->kind == KindBV && rhs_concrete->kind == KindBV) { + int new_width = lhs_concrete->width() + rhs_concrete->width(); + int64_t 
new_value = + (lhs_concrete->value.value << rhs_concrete->width()) | + rhs_concrete->value.value; + return SVFactory::make_smallbv(new_width, new_value); + } + } + } + if (auto extract1 = std::dynamic_pointer_cast(lhs.symptr)) { + if (auto extract2 = std::dynamic_pointer_cast(rhs.symptr)) { + if (extract1->low == extract2->high + 1 && + extract1->value == extract2->value) { + if (extract1->high == 4 && extract2->low == 1) { + // special case for full 4-byte extract concatenation + // TODO: support 64-bit later, this optimization is only valid when we + // only work on 32-bit values + return extract1->value; + } + // two extracts are adjacent, we can merge them + return extract1->value.extract(extract1->high, extract2->low); + } + } + } + return SVFactory::make_binary(CONCAT, lhs, rhs); +} + +} // namespace SVFactory + +#endif // WASM_SYMVAL_FACTORY_HPP diff --git a/genwasym_runtime/include/wasm/symval_impl.hpp b/genwasym_runtime/include/wasm/symval_impl.hpp new file mode 100644 index 00000000..b7b3684c --- /dev/null +++ b/genwasym_runtime/include/wasm/symval_impl.hpp @@ -0,0 +1,208 @@ +#ifndef WASM_SYMVAL_IMPL_HPP +#define WASM_SYMVAL_IMPL_HPP + +#include "symval_decl.hpp" +#include "symval_factory.hpp" +#include "wasm/concrete_num.hpp" + +inline SymVal SymVal::add(const SymVal &other) const { + return SVFactory::make_binary(ADD, *this, other); +} + +inline SymVal SymVal::minus(const SymVal &other) const { + return SVFactory::make_binary(SUB, *this, other); +} + +inline SymVal SymVal::mul(const SymVal &other) const { + return SVFactory::make_binary(MUL, *this, other); +} + +inline SymVal SymVal::div(const SymVal &other) const { + return SVFactory::make_binary(DIV, *this, other); +} + +inline SymVal SymVal::div_u(const SymVal &other) const { + return SVFactory::make_binary(DIV_U, *this, other); +} + +inline SymVal SymVal::land(const SymVal &other) const { + return SVFactory::make_binary(AND, *this, other); +} + +inline SymVal SymVal::lor(const SymVal &other) const { 
+ return SVFactory::make_binary(OR, *this, other); +} + +inline SymVal SymVal::eq_bool(const SymVal &other) const { + return SVFactory::make_binary(EQ_BOOL, *this, other); +} + +inline SymVal SymVal::neq_bool(const SymVal &other) const { + return SVFactory::make_binary(NEQ_BOOL, *this, other); +} + +inline SymVal SymVal::eq(const SymVal &other) const { + return SVFactory::make_binary(EQ_BOOL, *this, other); +} + +inline SymVal SymVal::neq(const SymVal &other) const { + return SVFactory::make_binary(NEQ_BOOL, *this, other); +} + +inline SymVal SymVal::bv2bool() const { + auto rhs = SVFactory::make_concrete_bv(I32V(0), symptr->width()); + return SVFactory::make_binary(NEQ_BOOL, *this, rhs); +} + +inline SymVal SymVal::bool2bv() const { + return SVFactory::make_unary(BOOL2BV, *this); +} + +inline SymVal SymVal::extend_to_i64() const { + return SVFactory::make_unary(EXTEND, *this); +} + +inline SymVal SymVal::lt(const SymVal &other) const { + return SVFactory::make_binary(LT_BOOL, *this, other); +} + +inline SymVal SymVal::ltu(const SymVal &other) const { + return SVFactory::make_binary(LTU_BOOL, *this, other); +} + +inline SymVal SymVal::le(const SymVal &other) const { + return SVFactory::make_binary(LEQ_BOOL, *this, other); +} + +inline SymVal SymVal::leu(const SymVal &other) const { + return SVFactory::make_binary(LEU_BOOL, *this, other); +} + +inline SymVal SymVal::gt(const SymVal &other) const { + return SVFactory::make_binary(GT_BOOL, *this, other); +} + +inline SymVal SymVal::gtu(const SymVal &other) const { + return SVFactory::make_binary(GTU_BOOL, *this, other); +} + +inline SymVal SymVal::ge(const SymVal &other) const { + return SVFactory::make_binary(GEQ_BOOL, *this, other); +} + +inline SymVal SymVal::geu(const SymVal &other) const { + return SVFactory::make_binary(GEU_BOOL, *this, other); +} + +inline SymVal SymVal::shl(const SymVal &other) const { + return SVFactory::make_binary(SHL, *this, other); +} + +inline SymVal SymVal::shr_u(const SymVal &other) 
const { + return SVFactory::make_binary(SHR_U, *this, other); +} + +inline SymVal SymVal::shr_s(const SymVal &other) const { + return SVFactory::make_binary(SHR_S, *this, other); +} + +inline SymVal SymVal::rem_u(const SymVal &other) const { + return SVFactory::make_binary(REM_U, *this, other); +} + +inline SymVal SymVal::is_zero() const { + return SVFactory::make_binary( + EQ_BOOL, *this, SVFactory::make_concrete_bv(I64V(0), symptr->width())); +} + +inline SymVal SymVal::bv_negate() const { + assert(symptr->width() != 1); + return SVFactory::make_binary( + EQ_BOOL, *this, SVFactory::make_concrete_bv(I64V(0), symptr->width())); +} + +inline SymVal SymVal::bool_not() const { + return SVFactory::make_unary(NOT, *this); +} + +inline SymVal SymVal::concat(const SymVal &other) const { + return SVFactory::make_concat(*this, other); +} + +inline SymVal SymVal::extract(int high, int low) const { + return SVFactory::make_extract(*this, high, low); +} + +inline SymVal SymVal::bitwise_and(const SymVal &other) const { + return SVFactory::make_binary(B_AND, *this, other); +} + +inline SymVal SymVal::bitwise_xor(const SymVal &other) const { + return SVFactory::make_binary(B_XOR, *this, other); +} + +inline SymVal SymVal::bitwise_or(const SymVal &other) const { + return SVFactory::make_binary(B_OR, *this, other); +} + +inline SymVal SymVal::get_witness_symbol() { + static SymVal witness = SymVal(SVFactory::SymBookKeeper.allocate()); + return witness; +} + +inline SymVal SymVal::makeI32Symbol() const { + if (auto concrete = dynamic_cast(symptr.get())) { + auto id = concrete->value.toInt(); + return SVFactory::make_int_symbolic(id, 32); + } + throw std::runtime_error( + "Cannot make symbolic a non-concrete symbolic value"); +} + +inline SymVal SymVal::makeI64Symbol() const { + if (auto concrete = dynamic_cast(symptr.get())) { + auto id = concrete->value.toInt(); + return SVFactory::make_int_symbolic(id, 64); + } + throw std::runtime_error( + "Cannot make symbolic a non-concrete 
symbolic value"); +} + +inline SymVal SymVal::makeF32Symbol() const { + if (auto concrete = dynamic_cast(symptr.get())) { + auto id = concrete->value.toInt(); + return SVFactory::make_fp_symbolic(id, 32); + } + throw std::runtime_error( + "Cannot make symbolic a non-concrete symbolic value"); +} + +inline SymVal SymVal::makeF64Symbol() const { + auto concrete = dynamic_cast(symptr.get()); + if (concrete) { + auto id = concrete->value.toInt(); + return SVFactory::make_fp_symbolic(id, 64); + } + throw std::runtime_error( + "Cannot make symbolic a non-concrete symbolic value"); +} + +inline bool SymVal::is_concrete() const { + return dynamic_cast(symptr.get()) != nullptr; +} + +inline SymVal Concrete(Num num, int width) { + // std::cout << "Creating concrete value: " << num.toInt() << " with width " + // << width + // << std::endl; + assert(width == 32 || width == 64); + return SVFactory::make_concrete_bv(num, width); +} + +inline SymVal FPConcrete(Num num, int width) { + assert(width == 32 || width == 64); + return SVFactory::make_concrete_fp(num, width); +} + + +#endif // WASM_SYMVAL_IMPL_HPP diff --git a/genwasym_runtime/include/wasm/union_find.hpp b/genwasym_runtime/include/wasm/union_find.hpp new file mode 100644 index 00000000..3ba13915 --- /dev/null +++ b/genwasym_runtime/include/wasm/union_find.hpp @@ -0,0 +1,60 @@ +#ifndef WASM_UNION_FIND_HPP +#define WASM_UNION_FIND_HPP +#include "config.hpp" +#include "immer/map.hpp" +#include +#include +#include + +// TODO: merge this file with headers/gensym/unionfind.hpp with a general implementation in a new PR +class UnionFind { +private: + immer::map_transient parent; + immer::map_transient rank; + +public: + UnionFind() = default; + + int find(int x) const { + auto parent_opt = parent.find(x); + if (!parent_opt) { + return x; + } + if (*parent_opt == x) { + return x; + } + return find(*parent_opt); + } + + void unite(int x, int y) { + int root_x = find(x); + int root_y = find(y); + + if (root_x == root_y) { + return; 
+ } + + auto rank_x_ptr = rank.find(root_x); + auto rank_y_ptr = rank.find(root_y); + int rank_x = rank_x_ptr ? *rank_x_ptr : 0; + int rank_y = rank_y_ptr ? *rank_y_ptr : 0; + + if (rank_x < rank_y) { + parent.set(root_x, root_y); + } else if (rank_x > rank_y) { + parent.set(root_y, root_x); + } else { + parent.set(root_y, root_x); + rank.set(root_x, rank_x + 1); + } + } + + bool connected(int x, int y) const { return find(x) == find(y); } + + void clear() { + parent = immer::map_transient(); + rank = immer::map_transient(); + } +}; + +#endif // WASM_UNION_FIND_HPP \ No newline at end of file diff --git a/genwasym_runtime/include/wasm/utils.hpp b/genwasym_runtime/include/wasm/utils.hpp new file mode 100644 index 00000000..62c5bb8a --- /dev/null +++ b/genwasym_runtime/include/wasm/utils.hpp @@ -0,0 +1,92 @@ +#ifndef UTILS_HPP +#define UTILS_HPP +#include +#include + +#ifndef GENSYM_ASSERT +#define GENSYM_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(std::string("Assertion failed: ") + \ + #condition + " (" + __FILE__ + ":" + \ + std::to_string(__LINE__) + ")"); \ + } \ + } while (0) +#endif + +#ifndef NO_DBG +#define GENSYM_DBG(obj) \ + do { \ + std::cout << "LOG: " << obj << " (" << __FILE__ << ":" \ + << std::to_string(__LINE__) << ")" << std::endl; \ + } while (0) +#else // NO_DBG: GENSYM_DBG must still be defined, as a no-op +#define GENSYM_DBG(obj) \ + do { \ + } while (0) +#endif + +#ifndef NO_INFO +#define GENSYM_INFO(obj) \ + do { \ + std::cout << obj << std::endl; \ + } while (0) +#else +#define GENSYM_INFO(message) \ + do { \ + } while (0) + +#endif + +#if __cplusplus < 202002L +#include + +inline bool starts_with(const std::string &str, const std::string &prefix) { + return str.size() >= prefix.size() && + std::equal(prefix.begin(), prefix.end(), str.begin()); +} +#else +#include +inline bool starts_with(const std::string &str, const std::string &prefix) { + return str.starts_with(prefix); +} +#endif + +inline std::monostate print_infos() { + std::cout << std::endl; + return 
std::monostate{}; +} + +template +std::monostate print_infos(const T &first, const Args &...args) { + std::cout << first << " "; + print_infos(args...); + return std::monostate{}; +} + +template +std::monostate info(const T &first, const Args &...args) { +#ifdef DEBUG + print_infos(first, args...); +#endif + return std::monostate{}; +} + +constexpr const char *DEBUG_OPTS_ENV_VAR = "GENSYM_DEBUG"; + +template +std::monostate infoWhen(const char *dbg_option, const Args &...args) { +#ifdef DEBUGWHEN + const char *env_value = std::getenv(DEBUG_OPTS_ENV_VAR); + if (env_value && std::string(env_value).find(std::string(dbg_option)) != + std::string::npos) { + print_infos(args...); + } +#endif + return std::monostate{}; +} + +inline std::monostate get_unit() { return std::monostate{}; } +inline std::monostate get_unit(std::monostate x) { return std::monostate{}; } + +#endif // UTILS_HPP \ No newline at end of file diff --git a/genwasym_runtime/include/wasm/z3_env.hpp b/genwasym_runtime/include/wasm/z3_env.hpp new file mode 100644 index 00000000..f9109406 --- /dev/null +++ b/genwasym_runtime/include/wasm/z3_env.hpp @@ -0,0 +1,36 @@ +#ifndef WASM_Z3_ENV_HPP +#define WASM_Z3_ENV_HPP +#include "z3++.h" + +struct Z3Env { + z3::context z3_ctx; + + Z3Env() : z3_ctx() {} +}; + +static Z3Env GLOBAL_Z3_ENV; + +inline z3::context &global_z3_ctx() { return GLOBAL_Z3_ENV.z3_ctx; } + +// A map from z3 expression id to their ast size +static std::unordered_map Z3ExprSizeMap; + +inline int get_z3_fp_sort_size(const z3::sort &s) { + assert(s.is_fpa()); + return s.fpa_ebits() + s.fpa_sbits(); +} + +static int get_z3_expr_size(const z3::expr &e) { + unsigned id = e.id(); + if (Z3ExprSizeMap.find(id) != Z3ExprSizeMap.end()) { + return Z3ExprSizeMap[id]; + } + unsigned count = 1; // count self + for (unsigned i = 0; i < e.num_args(); i++) { + count += get_z3_expr_size(e.arg(i)); + } + Z3ExprSizeMap[id] = count; + return count; +} + +#endif // WASM_Z3_ENV_HPP diff --git 
a/genwasym_runtime/lib/genwasym.cpp b/genwasym_runtime/lib/genwasym.cpp new file mode 100644 index 00000000..95233140 --- /dev/null +++ b/genwasym_runtime/lib/genwasym.cpp @@ -0,0 +1,5 @@ +#include "genwasym.h" + +int genwasym_dummy() { + return 0; +} \ No newline at end of file diff --git a/genwasym_runtime/lib/wasm_state_continue.cpp b/genwasym_runtime/lib/wasm_state_continue.cpp new file mode 100644 index 00000000..29621901 --- /dev/null +++ b/genwasym_runtime/lib/wasm_state_continue.cpp @@ -0,0 +1,116 @@ +#include "wasm_state_continue.hpp" + +cont_t fun_ret_cont_stack[1000]; +int fun_ret_cont_stack_ptr = 0; + +void push_fun_ret_cont_stack(cont_t cont) { + fun_ret_cont_stack[fun_ret_cont_stack_ptr++] = cont; +} + +std::monostate pop_fun_ret_cont_stack() { + return fun_ret_cont_stack[--fun_ret_cont_stack_ptr](std::monostate()); +} + +Value I32V(int x) { + Value v; + v.ty = I32; + v.i32 = x; + return v; +} + +State::State(immer::flex_vector memory, immer::flex_vector globals) + : memory(memory), globals(globals) { + for (int i = 0; i < 1000; i++) { + stack[i] = I32V(0); + } + return_stack = immer::vector_transient>(); +} + +Value State::stack_at(int i) { + return stack[i]; +} + +void State::push_stack(Value v) { + stack[stack_ptr++] = v; +} + +Value State::pop_stack() { + return stack[--stack_ptr]; +} + +Value State::peek_stack() { + return stack[stack_ptr - 1]; +} + +void State::print_stack() { + printf("sp: %ld, fp: %ld, Stack: ", stack_ptr, frame_ptr); + for (int i = 0; i < stack_ptr; i++) { + printf("%d ", stack[i].i32); + } + printf("\n"); +} + +Value State::get_local(int i) { + return stack[frame_ptr + i]; +} + +void State::set_local(int i, Value v) { + stack[frame_ptr + i] = v; +} + +void State::return_from_fun(int num_locals, int ret_num) { + remove_stack_range(frame_ptr - num_locals, frame_ptr); + remove_stack_range(frame_ptr + ret_num, stack_ptr); + stack_ptr = frame_ptr - num_locals + ret_num; +} + +void State::bump_frame_ptr() { + frame_ptr = 
stack_ptr; +} + +void State::set_frame_ptr(int fp) { + frame_ptr = fp; +} + +int State::get_frame_ptr() { + return frame_ptr; +} + +void State::save_frame_ptr() { + tmp_frame_ptr = frame_ptr; +} + +void State::restore_frame_ptr() { + frame_ptr = tmp_frame_ptr; +} + +void State::remove_stack_range(int start, int end) { + for (int i = start; i < end; i++) { + int j = end + (i - start); + if (j < stack_ptr) { + stack[i] = stack[j]; + } else { + stack[i] = I32V(0); + } + } +} + +void State::reverse_top_n(int n) { + for (int i = stack_ptr - n; i < stack_ptr - n / 2; i++) { + int j = stack_ptr - (i - (stack_ptr - n)) - 1; + Value tmp = stack[i]; + stack[i] = stack[j]; + stack[j] = tmp; + } +} + +State global_state = State(immer::flex_vector(), immer::flex_vector()); + +State& init_state(immer::flex_vector memory, + immer::flex_vector globals, + int num_locals) { + global_state = State(memory, globals); + global_state.stack_ptr = num_locals; + global_state.frame_ptr = num_locals; + return global_state; +} \ No newline at end of file diff --git a/src/test/scala/genwasym/CppCompilationTestBase.scala b/src/test/scala/genwasym/CppCompilationTestBase.scala index 8211bab2..aaf8e8f5 100644 --- a/src/test/scala/genwasym/CppCompilationTestBase.scala +++ b/src/test/scala/genwasym/CppCompilationTestBase.scala @@ -56,6 +56,26 @@ abstract class CppCompilationTestBase extends FunSuite { .filter(_.isDirectory) .map(_.getCanonicalPath) + protected lazy val genwasymRuntimeIncludeDir: String = { + val fromEnv = sys.env.get("GENWASYM_RUNTIME_INCLUDE_DIR") + val fromRepo = firstExistingDir(Seq("./genwasym_runtime/include")) + fromEnv.orElse(fromRepo).getOrElse { + throw new RuntimeException( + "Cannot locate GenWasm runtime include directory. Set GENWASYM_RUNTIME_INCLUDE_DIR or check genwasym_runtime/include." 
+ ) + } + } + + protected lazy val genwasymRuntimeLibDir: String = { + val fromEnv = sys.env.get("GENWASYM_RUNTIME_LIB_DIR") + val fromRepo = firstExistingDir(Seq("./genwasym_runtime/build")) + fromEnv.orElse(fromRepo).getOrElse { + throw new RuntimeException( + "Cannot locate GenWasm runtime library directory. Set GENWASYM_RUNTIME_LIB_DIR or build genwasym_runtime." + ) + } + } + protected def prependPath(existing: Option[String], prefix: String): String = existing.filter(_.nonEmpty).map(old => s"$prefix:$old").getOrElse(prefix) @@ -64,6 +84,11 @@ abstract class CppCompilationTestBase extends FunSuite { "DYLD_LIBRARY_PATH" -> prependPath(sys.env.get("DYLD_LIBRARY_PATH"), z3LibDir) ) + protected lazy val genwasymRuntimeEnv: Seq[(String, String)] = Seq( + "LD_LIBRARY_PATH" -> prependPath(sys.env.get("LD_LIBRARY_PATH"), genwasymRuntimeLibDir), + "DYLD_LIBRARY_PATH" -> prependPath(sys.env.get("DYLD_LIBRARY_PATH"), genwasymRuntimeLibDir) + ) + protected def compileGeneratedCpp(source: String, headerFolders: Seq[String], outputCpp: String, @@ -113,16 +138,16 @@ abstract class CppCompilationTestBase extends FunSuite { macroDefs: Seq[String] = Seq.empty): Unit = { compileGeneratedCpp( source = source, - headerFolders = headerFolders, + headerFolders = genwasymRuntimeIncludeDir +: headerFolders, outputCpp = outputCpp, outputExe = outputExe, compiler = compiler, optimizeLevel = optimizeLevel, extraIncludeDirs = immerIncludeDirs :+ z3IncludeDir, macroDefs = macroDefs, - libraryDirs = Seq(z3LibDir), - runtimeLibraryDirs = Seq(z3LibDir), - libraries = Seq("z3") + libraryDirs = Seq(genwasymRuntimeLibDir, z3LibDir), + runtimeLibraryDirs = Seq(genwasymRuntimeLibDir, z3LibDir), + libraries = Seq("genwasym", "z3") ) } @@ -130,7 +155,7 @@ abstract class CppCompilationTestBase extends FunSuite { Process(Seq(exePath), None, env: _*).!! 
protected def runExeWithZ3(exePath: String, extraEnv: Seq[(String, String)] = Seq.empty): String = - runExe(exePath, z3RuntimeEnv ++ extraEnv) + runExe(exePath, z3RuntimeEnv ++ genwasymRuntimeEnv ++ extraEnv) protected def parseStackValues(output: String): List[Float] = { val startMarker = "Stack contents: \n"