Commit 6d1fa223 authored by Jindřich Sedláček's avatar Jindřich Sedláček
Browse files

feat(unit propagation): begin refactor, already works with dpll

parent 85ccf7ce
Loading
Loading
Loading
Loading
+14 −13
Original line number Diff line number Diff line
#pragma once

#include <climits>
#include <deque>
#include <vector>

#include "formula.hpp"
@@ -37,10 +36,8 @@ private:

    struct VariableInfo
    {
        // new watch invariant: The watched literals are always the first two
        // literals in a clause
        std::vector<std::pair<ClauseIdx, bool>> watched_negated = {};
        std::vector<std::pair<ClauseIdx, bool>> watched_nonnegated = {};
        std::vector<ClauseIdx> watched_negated = {};
        std::vector<ClauseIdx> watched_nonnegated = {};

        ClauseIdx reason = NO_REASON;
        DecisionLevel decision_level = NO_LEVEL;
@@ -57,14 +54,15 @@ private:
    std::vector<VariableInfo> _variable_infos;
    std::vector<ClauseInfo> _clause_infos;

    std::deque<ClauseIdx> _unit_propagable = {};
    bool _try_propagating_last = false;

    std::vector<Literal> _trail = {};
    std::vector<TrailIdx> _decisions = {};
    DecisionLevel _highest_backtrack_level = NO_LEVEL;
    TrailIdx _earliest_flipped_decision = INT_MAX;

    // Invariant: after a unit propagation, this is equal to _trail.size()
    // Invariant: this is always less or equal to _trail.size()
    TrailIdx _next_unit_propagation = 0;

    std::vector<Variable> _unassigned_in_last_backtrack = {};

    int _original_clause_count;
@@ -81,9 +79,14 @@ public:
    int original_clause_count() const;

    const Clause& get_clause(ClauseIdx idx) const;
    AssignmentState clause_status(const Clause& clause) const;

    // This method is valid when _next_unit_propagation == _trail.size()
    // In other cases, it can return an incorrect value
    AssignmentState clause_status(ClauseIdx idx) const;

    // This method is always valid, but is slower than the previous one
    AssignmentState clause_status(const Clause& clause) const;

    AssignmentState variable_status(Variable variable) const;
    AssignmentState literal_status(Literal literal) const;

@@ -91,8 +94,6 @@ public:

    std::optional<ClauseIdx> reason(Variable variable) const;

    // contains one literal at the current decision level
    bool is_asserting(const Clause& clause) const;
    std::optional<DecisionLevel>
    lowest_unit_propagation_level(const Clause& clause) const;

@@ -138,14 +139,14 @@ private:
        OK
    };

    ConflictResult try_propagating_last_clause();
    void set_watch(ClauseIdx clause_idx, int variable_idx, bool first);
    void set_up_watch(ClauseIdx clause_idx, bool first);

    enum class ReassignWatchResult
    {
        COULD_NOT,
        KEPT_THE_SAME,
        REASSIGNED,
        CONFLICT,
    };

    ReassignWatchResult reassign_watch(ClauseIdx clause_idx, bool first);
+198 −211
Original line number Diff line number Diff line
@@ -15,12 +15,12 @@ Engine::Engine(const Formula& formula)

        if (formula.body[idx].size() == 1)
        {
            _unit_propagable.push_back(idx);
            set_literal_to_true(formula.body[idx].body[0]);
        }
        else
        {
            set_watch(idx, 0, true);
            set_watch(idx, 1, false);
            set_up_watch(idx, true);
            set_up_watch(idx, false);
        }
    }
}
@@ -40,32 +40,6 @@ const Clause& Engine::get_clause(ClauseIdx idx) const
    return _clause_infos[idx].clause;
}

AssignmentState Engine::clause_status(const Clause& clause) const
{
    bool has_unknown = false;

    for (auto lit : clause.body)
    {
        AssignmentState status = literal_status(lit);

        if (status.is_true())
        {
            return AssignmentState::TRUE;
        }
        else if (status.is_unknown())
        {
            has_unknown = true;
        }
    }

    if (has_unknown)
    {
        return AssignmentState::UNKNOWN;
    }

    return AssignmentState::FALSE;
}

AssignmentState Engine::clause_status(ClauseIdx idx) const
{
    const Clause& clause = get_clause(idx);
@@ -98,6 +72,32 @@ AssignmentState Engine::clause_status(ClauseIdx idx) const
    return AssignmentState::UNKNOWN;
}

AssignmentState Engine::clause_status(const Clause& clause) const
{
    bool has_unknown = false;

    for (auto lit : clause.body)
    {
        AssignmentState status = literal_status(lit);

        if (status.is_true())
        {
            return AssignmentState::TRUE;
        }
        else if (status.is_unknown())
        {
            has_unknown = true;
        }
    }

    if (has_unknown)
    {
        return AssignmentState::UNKNOWN;
    }

    return AssignmentState::FALSE;
}

AssignmentState Engine::variable_status(Variable variable) const
{
    assert(variable > 0);
@@ -109,14 +109,9 @@ AssignmentState Engine::literal_status(Literal literal) const
{
    assert(literal.get_variable() > 0);
    assert(literal.get_variable() < _variable_infos.size());
    auto var_status = variable_status(literal.get_variable());

    if (!literal.is_negated())
    {
        return var_status;
    }

    return var_status.flipped();
    auto var_status = variable_status(literal.get_variable());
    return (literal.is_negated()) ? var_status.flipped() : var_status;
}

Engine::Status Engine::status() const
@@ -125,7 +120,9 @@ Engine::Status Engine::status() const

    for (ClauseIdx idx = 0; idx < _clause_infos.size(); idx++)
    {
        const AssignmentState clause_state = clause_status(idx);
        const AssignmentState clause_state = (_next_unit_propagation == _trail.size())
                                                 ? clause_status(idx)
                                                 : clause_status(get_clause(idx));

        if (clause_state.is_false())
        {
@@ -160,26 +157,6 @@ std::optional<Engine::ClauseIdx> Engine::reason(Variable variable) const
    return reason;
}

bool Engine::is_asserting(const Clause& clause) const
{
    bool has_one_literal_at_current_decision_level = false;

    for (Literal lit : clause.body)
    {
        if (decision_level(lit.get_variable()) == current_decision_level())
        {
            if (has_one_literal_at_current_decision_level)
            {
                return false;
            }

            has_one_literal_at_current_decision_level = true;
        }
    }

    return has_one_literal_at_current_decision_level;
}

std::optional<Engine::DecisionLevel>
Engine::lowest_unit_propagation_level(const Clause& clause) const
{
@@ -263,12 +240,7 @@ std::optional<Engine::DecisionLevel> Engine::decision_level(Variable variable) c
    assert(variable < _variable_infos.size());
    auto level = _variable_infos[variable].decision_level;

    if (level == NO_LEVEL)
    {
        return std::nullopt;
    }

    return level;
    return (level == NO_LEVEL) ? std::nullopt : std::optional(level);
}

std::optional<Variable> Engine::last_decision() const
@@ -302,23 +274,82 @@ void Engine::learn(Clause clause)
{
    assert(_clause_infos.size() <= INT_MAX);
    assert(clause.size() >= 1);
    assert(!_try_propagating_last);

    _try_propagating_last = true;
    assert(!clause_status(clause).is_false());

    _clause_infos.emplace_back();
    ClauseInfo& info = _clause_infos.back();
    ClauseIdx idx = _clause_infos.size() - 1;

    info.clause = clause;
    info.clause = std::move(clause);

    if (clause.size() == 1)
    if (info.clause.size() == 1)
    {
        set_literal_to_true(info.clause.body[0]);
        return;
    }

    // we want to set as watched literals either unassigned
    // literals, or the literals that were assigned last
    // what we want here ... we know this clause is not false, so
    // it has at least one literal that is not false
    //
    // if it can be unit propagated, it is unit propagated
    //
    // the watches are set such that decisions and backtracks remain valid

    // we order the literals in the clause such that true literals are
    // first, unknown second, false last

    int true_idx = 0;
    int unknown_idx = 0;
   
    for (int idx = 0; idx < info.clause.size(); idx++)
    {
        AssignmentState state = literal_status(clause.body[idx]);

        if (state.is_true())
        {
            std::swap(clause.body[true_idx], clause.body[unknown_idx]);
            std::swap(clause.body[true_idx], clause.body[idx]);

            unknown_idx++;
            true_idx++;
        }
        else if (state.is_unknown())
        {
            std::swap(clause.body[unknown_idx], clause.body[idx]);
            unknown_idx++;
        }
    }

    // we can unit propagate the unknown (now first) literal
    if (true_idx == 0 && unknown_idx == 1)
    {
        set_literal_to_true(clause.body[0]);
    }

    return;

    bool has_true = false;
    int unassigned_count = 0;
    int idx_of_last_unassigned = -1;

    for (int i = 0; i < info.clause.size(); i++)
    {
        Literal& l = info.clause.body[i];
        const AssignmentState status = literal_status(l);

        has_true = has_true || literal_status(l).is_true();

        if (status.is_unknown())
        {
            unassigned_count++;
            idx_of_last_unassigned = i;
        }
    }

    if (!has_true && unassigned_count)
    {
        set_literal_to_true(info.clause.body[idx_of_last_unassigned]);
    }

    // a dirty trick...
    const int UNASSIGNED = INT_MAX;
@@ -329,9 +360,9 @@ void Engine::learn(Clause clause)
    int candidate_two_level = -1;

    // IMPORTANT: the clause has at least two literals here!
    for (int i = 0; i < clause.size(); i++)
    for (int i = 0; i < info.clause.size(); i++)
    {
        Literal lit = clause.body[i];
        Literal lit = info.clause.body[i];
        const int level = decision_level(lit.get_variable()).value_or(UNASSIGNED);

        if (level > candidate_one_level)
@@ -349,13 +380,27 @@ void Engine::learn(Clause clause)
        }
    }

    Literal candidate_one_literal = info.clause.body[candidate_one];
    Literal candidate_two_literal = info.clause.body[candidate_two];

    if (candidate_one > candidate_two)
    {
        std::swap(candidate_one, candidate_two);
    }

    set_watch(idx, candidate_one, true);
    set_watch(idx, candidate_two, false);
    int one_destination = 0;
    int two_destination = 1;

    if (candidate_two == 0)
    {
        std::swap(candidate_one, candidate_two);
    }

    std::swap(info.clause.body[one_destination], info.clause.body[candidate_one]);
    std::swap(info.clause.body[two_destination], info.clause.body[candidate_two]);

    set_up_watch(idx, true);
    set_up_watch(idx, false);
}

// NOTE: this method is slow, but it should not happen too often
@@ -388,30 +433,28 @@ std::vector<Engine::ClauseIdx> Engine::forget(std::vector<ClauseIdx> indices)

    for (VariableInfo& variable_info : _variable_infos)
    {
        std::vector<std::pair<ClauseIdx, bool>> new_watched_negated = {};
        std::vector<ClauseIdx> new_watched_negated = {};
        for (auto& watched_in : variable_info.watched_negated)
        {
            if (relocations[watched_in.first] == REMOVED)
            if (relocations[watched_in] == REMOVED)
            {
                continue;
            }

            new_watched_negated.push_back(
                {relocations[watched_in.first], watched_in.second});
            new_watched_negated.push_back(relocations[watched_in]);
        }

        variable_info.watched_negated = new_watched_negated;

        std::vector<std::pair<ClauseIdx, bool>> new_watched_nonnegated = {};
        std::vector<ClauseIdx> new_watched_nonnegated = {};
        for (auto& watched_in : variable_info.watched_nonnegated)
        {
            if (relocations[watched_in.first] == REMOVED)
            if (relocations[watched_in] == REMOVED)
            {
                continue;
            }

            new_watched_nonnegated.push_back(
                {relocations[watched_in.first], watched_in.second});
            new_watched_nonnegated.push_back(relocations[watched_in]);
        }

        variable_info.watched_nonnegated = new_watched_nonnegated;
@@ -451,7 +494,6 @@ void Engine::backtrack_to(DecisionLevel level)
    }

    clear_trail(_decisions[level] - 1);
    _unit_propagable.clear();

    // if we backtracked below the highest backtrack level, we need
    // to update the highest backtrack level
@@ -475,8 +517,6 @@ void Engine::flip_last_decision()
    Literal last_decision = _trail.back();
    unassign_last();

    _unit_propagable.clear();

    set_literal_to_true(last_decision.negated());

    // if we backtracked below the highest backtrack level, we need
@@ -497,44 +537,77 @@ void Engine::decide(Variable variable, VariablePhase phase)

std::optional<Engine::ClauseIdx> Engine::unit_propagate()
{
    if (_try_propagating_last &&
        try_propagating_last_clause() == ConflictResult::CONFLICT)
    while (_next_unit_propagation < _trail.size())
    {
        return _clause_infos.size() - 1;
    }
        const Literal literal = _trail[_next_unit_propagation];
        const Variable variable = literal.get_variable();
        VariableInfo& info = _variable_infos[literal.get_variable()];
        const AssignmentState variable_state = info.assignment_state;

        auto& watched =
            (variable_state.is_true()) ? info.watched_negated : info.watched_nonnegated;

        static std::vector<ClauseIdx> to_keep;
        to_keep.clear();

    while (!_unit_propagable.empty())
        std::optional<ClauseIdx> conflict_clause = std::nullopt;

        for (ClauseIdx idx : watched)
        {
            Clause& clause = _clause_infos[idx].clause;
            assert(clause.size() > 1);

            if (literal_status(clause.body[0]).is_true() ||
                literal_status(clause.body[1]).is_true())
            {
        ClauseIdx clause_idx = _unit_propagable.front();
        _unit_propagable.pop_front();
                // do not waste time reassigning watches when the clause is true
                to_keep.push_back(idx);
                continue;
            }

            int to_reassign = (clause.body[0].get_variable() == variable) ? 0 : 1;
            int other_watch = (to_reassign == 0) ? 1 : 0;

        const Clause& clause = get_clause(clause_idx);
            assert(literal_status(clause.body[to_reassign]).is_false());

        int this_watch = 1;
        if (clause.size() == 1 || literal_status(clause.body[1]).is_false())
            bool reassigned = false;
            for (int i = 2; i < clause.size(); i++)
            {
                if (!literal_status(clause.body[i]).is_false())
                {
            this_watch = 0;
                    std::swap(clause.body[i], clause.body[to_reassign]);
                    set_up_watch(idx, to_reassign == 0);
                    reassigned = true;
                    break;
                }
            }

        Literal this_literal = clause.body[this_watch];
        AssignmentState this_status = literal_status(this_literal);
            if (!reassigned)
            {
                to_keep.push_back(idx);

        if (this_status.is_false())
                if (literal_status(clause.body[other_watch]).is_false())
                {
            // assert(clause_status(clause).is_false());
            return clause_idx;
                    // we cannot return here, some watches may be reassigned
                    conflict_clause = idx;
                }
                else // the other literal is unknown
                {
                    set_literal_to_true(clause.body[other_watch]);
                    _variable_infos[clause.body[other_watch].get_variable()].reason = idx;
                }
            }
        }

        watched = to_keep;
        _next_unit_propagation++;

        if (this_status.is_unknown())
        if (conflict_clause.has_value())
        {
            // we can unit propagate!
            set_literal_to_true(this_literal);
            _variable_infos[this_literal.get_variable()].reason = clause_idx;
            return conflict_clause;
        }
    }

    // when we successfully unit-propagated everything in the list, then
    // we can backtrack here in the future!
    _highest_backtrack_level = current_decision_level();
    return std::nullopt;
}
@@ -551,32 +624,8 @@ void Engine::assign(Variable variable, VariablePhase phase)
    info.last_assignment_state.assign(phase);
    info.decision_level = _decisions.size();

    auto& watched = (phase == VariablePhase::POSITIVE) ? info.watched_negated
                                                       : info.watched_nonnegated;

    // the following vector is static; this is an optimization, because
    // we do not want to keep allocating new memory every time we call this
    // method
    static std::vector<std::pair<ClauseIdx, bool>> to_keep;

    to_keep.clear();

    for (auto [clause_idx, first] : watched)
    {
        switch (reassign_watch(clause_idx, first))
        {
        case ReassignWatchResult::COULD_NOT:
            _unit_propagable.push_back(clause_idx);
            // warning: fall through!
        case ReassignWatchResult::KEPT_THE_SAME:
            to_keep.push_back({clause_idx, first});
            break;
        case ReassignWatchResult::REASSIGNED:
            break;
        }
    }

    std::swap(watched, to_keep);
    // we do not change watches here anymore, that happens in unit propagation
    return;
}

void Engine::set_literal_to_true(Literal lit) { assign(lit.get_variable(), lit.phase()); }
@@ -595,6 +644,7 @@ void Engine::unassign_last()
    info.assignment_state.unassign();
    info.reason = -1;
    info.decision_level = -1;
    _next_unit_propagation = std::min(_next_unit_propagation, (int)_trail.size());

    _unassigned_in_last_backtrack.push_back(to_unassign);

@@ -621,86 +671,18 @@ void Engine::clear_trail(TrailIdx from)
    }
}

Engine::ConflictResult Engine::try_propagating_last_clause()
void Engine::set_up_watch(ClauseIdx clause_idx, bool first)
{
    _try_propagating_last = false;
    const int last_clause_idx = _clause_infos.size() - 1;
    const Clause& last_clause = _clause_infos.back().clause;
    const Literal& lit = get_clause(clause_idx).body[first ? 0 : 1];

    if (last_clause.size() == 1)
    if (lit.is_negated())
    {
        Literal lit = last_clause.body[0];
        AssignmentState status = literal_status(lit);

        if (status.is_false())
        {
            return ConflictResult::CONFLICT;
        _variable_infos[lit.get_variable()].watched_negated.push_back(clause_idx);
    }

        if (status.is_unknown())
        {
            set_literal_to_true(lit);
            // _reasons[lit.get_variable()] = last_clause_idx;
            _variable_infos[lit.get_variable()].reason = last_clause_idx;
        }

        return ConflictResult::OK;
    }

    std::optional<Literal> propagable = std::nullopt;

    if (literal_status(last_clause.body[0]).is_true() ||
        literal_status(last_clause.body[1]).is_true())
    {
        return ConflictResult::OK;
    }

    if (literal_status(last_clause.body[0]).is_false() &&
        literal_status(last_clause.body[1]).is_false())
    {
        return ConflictResult::CONFLICT;
    }

    if (variable_status(last_clause.body[0].get_variable()).is_unknown())
    {
        propagable = last_clause.body[0];
    }

    if (variable_status(last_clause.body[1].get_variable()).is_unknown())
    {
        if (propagable.has_value())
    else
    {
            // we cannot unit propagate, as both watches are unknown
            return ConflictResult::OK;
        }

        propagable = last_clause.body[1];
    }

    assert(propagable.has_value());

    set_literal_to_true(propagable.value());
    _variable_infos[propagable.value().get_variable()].reason = last_clause_idx;

    return ConflictResult::OK;
        _variable_infos[lit.get_variable()].watched_nonnegated.push_back(clause_idx);
    }

// TODO: The watch invariants can probably still be improved
void Engine::set_watch(ClauseIdx clause_idx, int variable_idx, bool first)
{
    Clause& clause = _clause_infos[clause_idx].clause;

    // a clause with just one variable has no watches!
    assert(clause.size() != 1);

    Literal& new_position = first ? clause.body[0] : clause.body[1];

    std::swap(new_position, clause.body[variable_idx]);

    auto& variable_info = _variable_infos[new_position.get_variable()];
    auto& _watched = (new_position.is_negated()) ? variable_info.watched_negated
                                                 : variable_info.watched_nonnegated;
    _watched.push_back({clause_idx, first});
}

Engine::ReassignWatchResult Engine::reassign_watch(ClauseIdx clause_idx, bool first)
@@ -725,10 +707,15 @@ Engine::ReassignWatchResult Engine::reassign_watch(ClauseIdx clause_idx, bool fi
    {
        if (!literal_status(clause.body[i]).is_false())
        {
            set_watch(clause_idx, i, first);
            // set_watch(clause_idx, i, first);
            return ReassignWatchResult::REASSIGNED;
        }
    }

    if (literal_status(other_watched).is_false())
    {
        return ReassignWatchResult::CONFLICT;
    }

    return ReassignWatchResult::COULD_NOT;
}
+7 −7
Original line number Diff line number Diff line
@@ -238,7 +238,7 @@ SolverResult Solver::run_CDCL()

        while (unit_propagation_result.has_value())
        {
            if (_engine.current_decision_level() == 0)
            if (_engine.can_backtrack())
            {
                return SolverResult::UNSAT;
            }
@@ -248,12 +248,6 @@ SolverResult Solver::run_CDCL()
            // we can always unit propagate the learnt clause somewhere
            auto new_level = _engine.lowest_unit_propagation_level(learnt).value();

            _engine.learn(learnt);
            _forgetting_manager->clause_added();

            _restart_manager->register_conflict();
            _forgetting_manager->register_conflict();

            // this way we restart right after the number of conflicts is reached
            if (_restart_manager->should_restart())
            {
@@ -271,6 +265,12 @@ SolverResult Solver::run_CDCL()
                _engine.backtrack_to(new_level);
            }

            _engine.learn(learnt);
            _forgetting_manager->clause_added();

            _restart_manager->register_conflict();
            _forgetting_manager->register_conflict();

            _variable_picker->finish_conflict_learning(_engine,
                                                       _engine.current_decision_level());