Commit 75121571 authored by Jindřich Sedláček's avatar Jindřich Sedláček
Browse files

feat(unit propagation): finish refactor, speed up forgetting

parent 8aa775bb
Loading
Loading
Loading
Loading
+4 −20
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ private:
    DecisionLevel _highest_backtrack_level = NO_LEVEL;
    TrailIdx _earliest_flipped_decision = INT_MAX;

    // Invariant: after a unit propagation, this is equal to _trail.size()
    // Invariant: after a successful unit propagation, this is equal to _trail.size()
    // Invariant: this is always less or equal to _trail.size()
    TrailIdx _next_unit_propagation = 0;

@@ -115,9 +115,9 @@ public:
    // Invariant: the indices of the clauses that the object was constructed
    // with does not change
    //
    // Returns relocations of the clauses, that is, the new locations of
    // clauses; the value is -1 if the clause was deleted
    std::vector<ClauseIdx> forget(std::vector<ClauseIdx> indices);
    // Returns relocations of the clauses, that is, for each clause index
    // its new clause index after relocation; -1 means the clause was removed
    const std::vector<ClauseIdx> &forget(std::vector<ClauseIdx> indices);

    void reset();
    void backtrack_to(DecisionLevel level);
@@ -133,21 +133,5 @@ private:
    void unassign_last();
    void clear_trail(TrailIdx from);

    enum class ConflictResult
    {
        CONFLICT,
        OK
    };

    void set_up_watch(ClauseIdx clause_idx, bool first);

    enum class ReassignWatchResult
    {
        COULD_NOT,
        KEPT_THE_SAME,
        REASSIGNED,
        CONFLICT,
    };

    ReassignWatchResult reassign_watch(ClauseIdx clause_idx, bool first);
};
+67 −170
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ Engine::Engine(const Formula& formula)
{
    _decisions.reserve(formula.num_variables);
    _trail.reserve(formula.num_variables);
    _unassigned_in_last_backtrack.reserve(formula.num_variables);

    for (ClauseIdx idx = 0; idx < formula.num_clauses; idx++)
    {
@@ -16,6 +17,9 @@ Engine::Engine(const Formula& formula)
        if (formula.body[idx].size() == 1)
        {
            set_literal_to_true(formula.body[idx].body[0]);
            // one-literal clauses are technically never used for learning;
            // this exists just in case it changes in the future
            _variable_infos[formula.body[idx].body[0].get_variable()].reason = idx;
        }
        else
        {
@@ -42,21 +46,19 @@ const Clause& Engine::get_clause(ClauseIdx idx) const

AssignmentState Engine::clause_status(ClauseIdx idx) const
{
    assert(_next_unit_propagation == _trail.size());
    const Clause& clause = get_clause(idx);

    const Literal first_literal = clause.body[0];
    const Literal second_literal = clause.body[1];

    if (literal_status(first_literal).is_false() &&
        literal_status(second_literal).is_false())
    if (literal_status(clause.body[0]).is_false() &&
        literal_status(clause.body[1]).is_false())
    {
        // if both watched literals are false, then the whole clause is false - otherwise,
        // they would have been reassigned
        return AssignmentState::FALSE;
    }

    if (literal_status(first_literal).is_true() ||
        literal_status(second_literal).is_true())
    if (literal_status(clause.body[0]).is_true() ||
        literal_status(clause.body[1]).is_true())
    {
        return AssignmentState::TRUE;
    }
@@ -147,14 +149,9 @@ std::optional<Engine::ClauseIdx> Engine::reason(Variable variable) const
{
    assert(variable > 0);
    assert(variable < _variable_infos.size());
    auto reason = _variable_infos[variable].reason;

    if (reason == NO_REASON)
    {
        return std::nullopt;
    }

    return reason;
    auto reason = _variable_infos[variable].reason;
    return (reason == NO_REASON) ? std::nullopt : std::optional(reason);
}

std::optional<Engine::DecisionLevel>
@@ -285,16 +282,10 @@ void Engine::learn(Clause clause)
    if (info.clause.size() == 1)
    {
        set_literal_to_true(info.clause.body[0]);
        _variable_infos[info.clause.body[0].get_variable()].reason = idx;
        return;
    }

    // what we want here ... we know this clause is not false, so
    // it has at least one literal that is not false
    //
    // if it can be unit propagated, it is unit propagated
    //
    // the watches are set such that decisions and backtracks remain valid

    // we order the literals in the clause such that true literals are
    // first, unknown second, false last

@@ -320,153 +311,88 @@ void Engine::learn(Clause clause)
        }
    }

    // we can unit propagate the unknown (now first) literal
    if (true_idx == 0 && unknown_idx == 1)
    {
        // we can unit propagate the one unknown (now first) literal
        _variable_infos[info.clause.body[0].get_variable()].reason = idx;
        set_literal_to_true(info.clause.body[0]);
    }
    else
    {
        // our solver should never get here as it is written;
        // assert(false);
    }

    set_up_watch(idx, true);
    set_up_watch(idx, false);

    return;

    bool has_true = false;
    int unassigned_count = 0;
    int idx_of_last_unassigned = -1;

    for (int i = 0; i < info.clause.size(); i++)
    {
        Literal& l = info.clause.body[i];
        const AssignmentState status = literal_status(l);

        has_true = has_true || literal_status(l).is_true();

        if (status.is_unknown())
        {
            unassigned_count++;
            idx_of_last_unassigned = i;
        }
    }

    if (!has_true && unassigned_count)
    {
        set_literal_to_true(info.clause.body[idx_of_last_unassigned]);
}

    // a dirty trick...
    const int UNASSIGNED = INT_MAX;

    int candidate_one = -1;
    int candidate_one_level = -1;
    int candidate_two = -1;
    int candidate_two_level = -1;

    // IMPORTANT: the clause has at least two literals here!
    for (int i = 0; i < info.clause.size(); i++)
    {
        Literal lit = info.clause.body[i];
        const int level = decision_level(lit.get_variable()).value_or(UNASSIGNED);

        if (level > candidate_one_level)
const std::vector<Engine::ClauseIdx>& Engine::forget(std::vector<ClauseIdx> indices)
{
            std::swap(candidate_one, candidate_two);
            std::swap(candidate_one_level, candidate_two_level);

            candidate_one = i;
            candidate_one_level = level;
        }
        else if (level > candidate_two_level)
        {
            candidate_two = i;
            candidate_two_level = level;
        }
    }

    Literal candidate_one_literal = info.clause.body[candidate_one];
    Literal candidate_two_literal = info.clause.body[candidate_two];

    if (candidate_one > candidate_two)
    {
        std::swap(candidate_one, candidate_two);
    }
    const int REMOVED = -1;

    int one_destination = 0;
    int two_destination = 1;
    // invariant ... no value is true in this map when this function begins
    static std::vector<bool> remove_map;
    remove_map.resize(clause_count(), false);

    if (candidate_two == 0)
    for (ClauseIdx idx : indices)
    {
        std::swap(candidate_one, candidate_two);
        remove_map[idx] = true;
    }

    std::swap(info.clause.body[one_destination], info.clause.body[candidate_one]);
    std::swap(info.clause.body[two_destination], info.clause.body[candidate_two]);
    // indices ... of type ClauseIdx
    static std::vector<ClauseIdx> relocations;
    relocations.clear();
    relocations.reserve(clause_count());

    set_up_watch(idx, true);
    set_up_watch(idx, false);
}

// NOTE: this method is slow, but it should not happen too often
std::vector<Engine::ClauseIdx> Engine::forget(std::vector<ClauseIdx> indices)
{
    const int REMOVED = -1;

    // static to avoid allocations
    std::vector<ClauseInfo> new_clause_infos;
    new_clause_infos.reserve(_clause_infos.size() - indices.size());

    // indices ... clause indices
    std::vector<ClauseIdx> relocations;

    // Doing this in linear time is worthless, considering how
    // much has to happen in the rest of this method...
    int new_size = 0;
    for (ClauseIdx idx = 0; idx < _clause_infos.size(); idx++)
    {
        if (_clause_infos[idx].clause.size() < 3 || idx < original_clause_count() ||
            std::find(indices.begin(), indices.end(), idx) != indices.end())
            !remove_map[idx])
        {
            new_clause_infos.push_back(_clause_infos[idx]);
            relocations.push_back(new_clause_infos.size() - 1);
            // the move must happen into a local variable first, as loc = std::move[loc]
            // is undefined...
            ClauseInfo moved = std::move(_clause_infos[idx]);
            relocations.push_back(new_size);
            _clause_infos[new_size++] = std::move(moved);
        }
        else
        {
            relocations.push_back(REMOVED);
        }
        remove_map[idx] = false;
    }

    _clause_infos.resize(new_size);

    for (VariableInfo& variable_info : _variable_infos)
    {
        std::vector<ClauseIdx> new_watched_negated = {};
        for (auto& watched_in : variable_info.watched_negated)
        int new_size = 0;
        for (int watch_i = 0; watch_i < variable_info.watched_negated.size(); watch_i++)
        {
            ClauseIdx watched_in = variable_info.watched_negated[watch_i];

            if (relocations[watched_in] == REMOVED)
            {
                continue;
            }

            new_watched_negated.push_back(relocations[watched_in]);
            variable_info.watched_negated[new_size++] = relocations[watched_in];
        }

        variable_info.watched_negated = new_watched_negated;
        variable_info.watched_negated.resize(new_size);

        std::vector<ClauseIdx> new_watched_nonnegated = {};
        for (auto& watched_in : variable_info.watched_nonnegated)
        new_size = 0;
        for (int watch_i = 0; watch_i < variable_info.watched_nonnegated.size();
             watch_i++)
        {
            ClauseIdx watched_in = variable_info.watched_nonnegated[watch_i];

            if (relocations[watched_in] == REMOVED)
            {
                continue;
            }

            new_watched_nonnegated.push_back(relocations[watched_in]);
            variable_info.watched_nonnegated[new_size++] = relocations[watched_in];
        }

        variable_info.watched_nonnegated = new_watched_nonnegated;
        variable_info.watched_nonnegated.resize(new_size);

        if (variable_info.reason != NO_REASON)
        {
@@ -474,7 +400,6 @@ std::vector<Engine::ClauseIdx> Engine::forget(std::vector<ClauseIdx> indices)
        }
    }

    _clause_infos = new_clause_infos;
    return relocations;
}

@@ -556,21 +481,29 @@ std::optional<Engine::ClauseIdx> Engine::unit_propagate()
        auto& watched =
            (variable_state.is_true()) ? info.watched_negated : info.watched_nonnegated;

        static std::vector<ClauseIdx> to_keep;
        to_keep.clear();

        std::optional<ClauseIdx> conflict_clause = std::nullopt;

        for (ClauseIdx idx : watched)
        int new_size = 0;
        for (int watch_i = 0; watch_i < watched.size(); watch_i++)
        {
            ClauseIdx& idx = watched[watch_i];

            // do not bother with reassigning the other watches when
            // we already know there is a conflict
            if (conflict_clause.has_value())
            {
                watched[new_size++] = idx;
                continue;
            }

            Clause& clause = _clause_infos[idx].clause;
            assert(clause.size() > 1);

            // do not waste time reassigning watches when the clause is true
            if (literal_status(clause.body[0]).is_true() ||
                literal_status(clause.body[1]).is_true())
            {
                // do not waste time reassigning watches when the clause is true
                to_keep.push_back(idx);
                watched[new_size++] = idx;
                continue;
            }

@@ -593,14 +526,14 @@ std::optional<Engine::ClauseIdx> Engine::unit_propagate()

            if (!reassigned)
            {
                to_keep.push_back(idx);
                watched[new_size++] = idx;

                if (literal_status(clause.body[other_watch]).is_false())
                {
                    // we cannot return here, some watches may be reassigned
                    // we cannot return here, some watches may already be reassigned
                    conflict_clause = idx;
                }
                else // the other literal is unknown
                else // the other watch is unknown
                {
                    set_literal_to_true(clause.body[other_watch]);
                    _variable_infos[clause.body[other_watch].get_variable()].reason = idx;
@@ -608,7 +541,7 @@ std::optional<Engine::ClauseIdx> Engine::unit_propagate()
            }
        }

        watched = to_keep;
        watched.resize(new_size);
        _next_unit_propagation++;

        if (conflict_clause.has_value())
@@ -634,7 +567,6 @@ void Engine::assign(Variable variable, VariablePhase phase)
    info.decision_level = _decisions.size();

    // we do not change watches here anymore, that happens in unit propagation
    return;
}

void Engine::set_literal_to_true(Literal lit) { assign(lit.get_variable(), lit.phase()); }
@@ -693,38 +625,3 @@ void Engine::set_up_watch(ClauseIdx clause_idx, bool first)
        _variable_infos[lit.get_variable()].watched_nonnegated.push_back(clause_idx);
    }
}

Engine::ReassignWatchResult Engine::reassign_watch(ClauseIdx clause_idx, bool first)
{
    Clause& clause = _clause_infos[clause_idx].clause;
    assert(clause.size() != 1);

    Literal& this_watched = first ? clause.body[0] : clause.body[1];
    Literal& other_watched = first ? clause.body[1] : clause.body[0];

    assert(literal_status(this_watched).is_false());

    if (literal_status(other_watched).is_true())
    {
        // reasoning: the false watch was assigned later than the true watch,
        // so if the clause stops being true, then both the false watch
        // and the true watch will be reassigned
        return ReassignWatchResult::KEPT_THE_SAME;
    }

    for (int i = 2; i < clause.size(); i++)
    {
        if (!literal_status(clause.body[i]).is_false())
        {
            // set_watch(clause_idx, i, first);
            return ReassignWatchResult::REASSIGNED;
        }
    }

    if (literal_status(other_watched).is_false())
    {
        return ReassignWatchResult::CONFLICT;
    }

    return ReassignWatchResult::COULD_NOT;
}
+1 −1
Original line number Diff line number Diff line
@@ -256,7 +256,7 @@ SolverResult Solver::run_CDCL()
                if (_forgetting_manager->time_to_forget())
                {
                    auto to_forget = _forgetting_manager->clauses_to_forget();
                    auto relocations = _engine.forget(std::move(to_forget));
                    auto relocations = _engine.forget(to_forget);
                    _forgetting_manager->register_restructuring(std::move(relocations));
                }
            }
+1 −1
Original line number Diff line number Diff line
@@ -528,7 +528,7 @@ int main(int argc, char* argv[])
    auto restart_method = Solver::RestartMethod::LUBY;
    const int restart_constant = 2;

    Solver solver(formula.value(), options.get_solver_options());
    Solver solver(std::move(formula.value()), options.get_solver_options());

    SolverResult res = solver.run();
    std::vector<Literal> model;