Commit 0a6ca549 authored by Martin Jonáš's avatar Martin Jonáš
Browse files

Implement proper solver using extensions.

parent 58377f1e
Loading
Loading
Loading
Loading
+74 −2
Original line number Diff line number Diff line
@@ -460,7 +460,6 @@ z3::expr ExprSimplifier::RemoveExistentials(z3::expr e)
            z3::symbol current_symbol(*context, z3_symbol);
            auto name = current_symbol.str();


            if (Z3_is_quantifier_forall(*context, ast))
            {
                if (current_sort.is_bool())
@@ -494,6 +493,79 @@ z3::expr ExprSimplifier::RemoveExistentials(z3::expr e)
    }
}

z3::expr ExprSimplifier::SubstituteExistentials(z3::expr e, std::map<std::string, z3::expr>& model, std::vector<std::string>& boundVars)
{
    if (e.is_numeral())
    {
        return e;
    }

    if (e.is_const())
    {
        std::string name = e.to_string();

        if (model.find(name) != model.end())
        {
            z3::expr value =  model.at(name);
            return to_expr(e.ctx(), Z3_translate(value.ctx(), value, e.ctx()));
        }

        return e;
    }
    else if (e.is_var())
    {
        Z3_ast ast = (Z3_ast)e;
        int deBruijnIndex = Z3_get_index_value(*context, ast);
        std::string name = boundVars[boundVars.size() - deBruijnIndex - 1];

        if (model.find(name) != model.end())
        {
            z3::expr value = model.at(name);
            return to_expr(e.ctx(), Z3_translate(value.ctx(), value, e.ctx()));
        }

        return e.is_bool() ? e.ctx().bool_const(name.c_str()) : e.ctx().bv_const(name.c_str(), e.get_sort().bv_size());

        return e;
    }
    else if (e.is_app())
    {
	func_decl dec = e.decl();
	int numArgs = e.num_args();

	expr_vector arguments(*context);
	for (int i = 0; i < numArgs; i++)
        {
	    arguments.push_back(SubstituteExistentials(e.arg(i), model, boundVars));
        }

	expr result = dec(arguments);
	return result;
    }
    else if (e.is_quantifier())
    {
	Z3_ast ast = (Z3_ast)e;
	int boundVariables = Z3_get_quantifier_num_bound(*context, ast);

	for (int i = 0; i < boundVariables; i++)
	{
	    Z3_symbol z3_symbol = Z3_get_quantifier_bound_name(*context, ast, i);
	    symbol current_symbol(*context, z3_symbol);

	    boundVars.push_back(current_symbol.str());
	}

        auto newBody = SubstituteExistentials(e.body(), model, boundVars);
        e = modifyQuantifierBody(e, newBody);

        boundVars.erase(boundVars.end() - boundVariables, boundVars.end());
        return e;
    }

    std::cout << "Unsupported " << e << std::endl;
    exit(1);
}

bool ExprSimplifier::isSentence(const z3::expr &e)
{
    auto item = isSentenceCache.find((Z3_ast)e);
+1 −0
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ public:
    z3::expr DeCanonizeBoundVariables(const z3::expr&);
    z3::expr StripToplevelExistentials(z3::expr, bool isPositive);
    z3::expr RemoveExistentials(z3::expr);
    z3::expr SubstituteExistentials(z3::expr, std::map<std::string, z3::expr>& model, std::vector<std::string>& boundVars);
    z3::expr ReduceDivRem(const z3::expr&);

private:
+2 −2
Original line number Diff line number Diff line
@@ -48,15 +48,15 @@ void FormulaStats::AddConstant(const z3::expr &e, const z3::sort &s)

void FormulaStats::AddVariable(const std::string &name, const z3::sort &s)
{
    variables.insert(name);

    if (s.is_bool())
    {
	maxBitWidth = std::max(maxBitWidth, 0u);
        variables.insert({name, 0});
    }
    else if (s.is_bv())
    {
	maxBitWidth = std::max(maxBitWidth, s.bv_size());
        variables.insert({name, s.bv_size()});
    }
}

+3 −2
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
#define FORMULASTATS_H
#include "z3++.h"
#include <set>
#include <map>
#include <iostream>

class FormulaStats
@@ -20,8 +21,8 @@ public:
    unsigned int maxBitWidth = 0;

    std::set<Z3_ast> constantASTs;
    std::set<std::pair<std::string, int>> constants;
    std::set<std::string> variables;
    std::map<std::string, int> constants;
    std::map<std::string, int> variables;
    std::set<std::string> numerals;

    unsigned int numeralCount = 0;
+28 −10
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@

#include <algorithm>
#include <numeric>
#include <regex>

#include "SMTLIBInterpreter.h"
#include "Logger.h"
@@ -56,7 +57,9 @@ void SMTLIBInterpreter::addConstant(const std::string& name, const z3::sort& s)
    }
    else if (s.is_bv())
    {
        constants.insert({name, ctx.bv_const(name.c_str(), s.bv_size())});
        std::regex varRegex ("([^!]+)(![0-9]+)*"); //TODO nedefinovat znovu
        std::string fixedName = std::regex_replace(name, varRegex, "$1");
        constants.insert({fixedName, ctx.bv_const(name.c_str(), s.bv_size())});
    }
}

@@ -70,8 +73,10 @@ z3::expr SMTLIBInterpreter::addVar(const std::string& name, const z3::sort& s)
    }
    else if (s.is_bv())
    {
        auto newVar = ctx.bv_const(name.c_str(), s.bv_size());
        variables.push_back({name, newVar});
        std::regex varRegex ("([^!]+)(![0-9]+)*"); //TODO nedefinovat znovu
        std::string fixedName = std::regex_replace(name, varRegex, "$1");
        auto newVar = ctx.bv_const(fixedName.c_str(), s.bv_size());
        variables.push_back({fixedName, newVar});
        return newVar;
    }
    exit(1);
@@ -84,7 +89,8 @@ void SMTLIBInterpreter::addVarBinding(const std::string& name, const z3::expr& e

void SMTLIBInterpreter::addFunctionDefinition(const std::string& name, const z3::expr_vector& args, const z3::expr& body)
{
    funDefinitions.insert({name, {args, body}});
    std::regex varRegex ("([^!]+)(![0-9]+)*");
    funDefinitions.insert({std::regex_replace(name, varRegex, "$1"), {args, body}});
}

void SMTLIBInterpreter::addSortDefinition(const std::string& name,  const z3::sort& sort)
@@ -94,17 +100,20 @@ void SMTLIBInterpreter::addSortDefinition(const std::string& name, const z3::so

z3::expr SMTLIBInterpreter::getConstant(const std::string& name) const
{
    std::regex varRegex ("([^!]+)(![0-9]+)*"); //TODO nedefinovat znovu
    std::string fixedName = std::regex_replace(name, varRegex, "$1");

    auto varItem = std::find_if(
        variables.rbegin(),
        variables.rend(),
        [name] (const auto& it) { return it.first == name; });
        [fixedName] (const auto& it) { return it.first == fixedName; });

    if (varItem != variables.rend())
    {
        return varItem->second;
    }

    auto item = constants.find(name);
    auto item = constants.find(fixedName);
    if (item != constants.end())
    {
        return item->second;
@@ -113,14 +122,14 @@ z3::expr SMTLIBInterpreter::getConstant(const std::string& name) const
    auto bindItem = std::find_if(
        variableBindings.rbegin(),
        variableBindings.rend(),
        [name] (const auto& it) { return it.first == name; });
        [fixedName] (const auto& it) { return it.first == fixedName; });

    if (bindItem != variableBindings.rend())
    {
        return bindItem->second;
    }

    std::cout << "Unknown constant " << name << std::endl;
    std::cout << "Unknown constant " << fixedName << std::endl;
    exit(1);
}

@@ -307,8 +316,17 @@ antlrcpp::Any SMTLIBInterpreter::visitCommand(SMTLIBv2Parser::CommandContext* co
        {
            expr = expr && z3::mk_and(assert);
        }

        Solver s;
        if (dual)
        {
            result = s.SolveDual(expr);
        }
        else
        {
            result = s.Solve(expr);
        }

        std::cout << (result == SAT ? "sat" :
                      result == UNSAT ? "unsat" :
                      "unknown") << std::endl;
Loading