Commit f658ebb8 authored by Vladimír Uhlík's avatar Vladimír Uhlík
Browse files

src: Rewrite conflict analysis according to miniSAT implementation.

parent 167f2c9c
Loading
Loading
Loading
Loading
+49 −81
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <optional>
#include <set>
#include <utility>
@@ -254,11 +255,6 @@ namespace zilla
            return is_unassigned( left )  ? *left :
                   is_unassigned( right ) ? *right : nullptr;
        }

        bool contains_lit( literal *needle ) const
        {
            return std::binary_search( lits.begin(), lits.end(), needle, lit_cmp );
        }
    };

    struct formula
@@ -272,22 +268,29 @@ namespace zilla
        vec<     idx_t > trail;
        vec<     idx_t > decisions;
        vec< literal * > to_resolve;
        vec<     idx_t > seen;

        idx_t timestamp;
        double act_const;
        heap< variable > var_prior;

        formula( idx_t vars, idx_t cls ) :
            nvariables( vars ),
            nclauses( cls ),
            timestamp( 0 ),
            act_const( 1 ),
            var_prior( vars, variables )
        {
            variables.resize( nvariables );
            seen.resize( nvariables );
            literals.resize( nvariables * 2 );
            clauses.reserve( nclauses * 2 );

            for ( var_t i = 0; i < nvariables; ++ i )
            {
                variables[ i ].set_var( i + 1 );
                seen[ i ] = 0;
            }
        }

        formula( formula && )                 = default;
@@ -410,7 +413,7 @@ namespace zilla

        void backtrack( idx_t backtracklevel )
        {
            while ( decisions.back() > backtracklevel )
            while ( decisions.size() > 1 && decisions.back() > backtracklevel )
                decisions.pop_back();

            while ( decisions.back() != trail.size() )
@@ -458,93 +461,58 @@ namespace zilla
            return std::make_pair( SatState::UNSAT, 0 );
        }

        std::pair< clause, idx_t > conflict_analysis( idx_t c_idx )
        std::pair< vec< literal * >, idx_t > conflict_analysis( idx_t c_idx )
        {
            clause c = clauses[ c_idx ];
            vec< literal * > learnt;
            int conf_path = 0;
            literal *p = nullptr;

            auto dl_count = [ this, &c ]()
            {
            if ( decisions.empty() )
                    return 1u;
                return std::make_pair( std::move( learnt ), max_idx );

                unsigned count = 0;
                idx_t level = decisions.back();
            idx_t index = trail.size() - 1;
            idx_t bt_level = 0;
            ++ timestamp;

                for ( auto *lit : c.lits )
            do
            {
                    if ( lit->dual->get_level() >= level && lit->dual->get_level() != max_level )
                        ++ count;
                    if ( count >= 2 )
                        return 2u;
                }
                clause &c = clauses[ c_idx ];

                return count;
            };

            auto highest = [ this, &c ]( auto it )
                for ( auto *lit : c.lits )
                {
                if ( decisions.empty() )
                    return max_level;

                for ( ; it != trail.rend() && !c.contains_lit( literals[ *it ].dual ); ++ it );

                auto trail_idx = static_cast< idx_t >( trail.rend() - it );
                return trail_idx >= decisions.back() ? decisions.back() :
                       literals[ trail[ trail_idx ] ].get_level();
            };
                    if ( lit == p )
                        continue;

            if ( dl_count() == 1 )
                return std::make_pair( std::move( c ), highest( trail.rbegin() ) );
                    var_t var = ( *lit ).get_var().get_var() - 1;

            for ( auto it = trail.rbegin(); it != trail.rend(); ++ it )
                    if ( seen[ var ] != timestamp && lit->dual->get_level() != max_idx )
                    {
                if ( c.contains_lit( literals[ *it ].dual ) )
                {
                    resolve( *it, c, clauses[ literals[ *it ].reason.value() ] );

                    auto &var = literals[ *it ].get_var();
                    var_prior.bump_activity( var.get_var() - 1, act_const );
                        seen[ var ] = timestamp;
                        var_prior.bump_activity( var, act_const );
                        update_act();

                        if ( lit->dual->get_level() == decisions.back() )
                            ++ conf_path;
                        else
                        {
                            learnt.push_back( lit );
                            auto lvl = lit->get_level();
                            if ( lvl != max_idx && lvl > bt_level )
                                bt_level = lvl;
                        }
                if ( dl_count() == 1 )
                    return std::make_pair( std::move( c ), highest( ++ it ) );
                    }

            std::unreachable();
                }

        void resolve( idx_t rlit, auto &prem1, const auto &prem2 )
        {
            vec< literal * > resolvent;
            resolvent.reserve( prem1.lits.size() + prem2.lits.size() );
            auto it_p1 = prem1.lits.begin();
            auto it_p2 = prem2.lits.begin();
            auto rem_from_p1 =  literals[ rlit ].dual;
            auto rem_from_p2 = &literals[ rlit ];

            auto append = [ &resolvent ]( auto &it, auto to_remove )
            {
                if ( *it != to_remove && ( resolvent.empty() || *it != resolvent.back() ) )
                    resolvent.push_back( *it );
                ++ it;
            };
            auto finish = [ &append ]( auto it, auto &lits, auto to_remove )
            {
                for ( ; it != lits.end(); )
                    append( it, to_remove );
            };
                while ( seen[ literals[ trail[ index -- ] ].get_var().get_var() - 1 ] != timestamp );

            while ( it_p1 != prem1.lits.end() && it_p2 != prem2.lits.end() )
            {
                if ( **it_p1 < **it_p2 )
                    append( it_p1, rem_from_p1 );
                else
                    append( it_p2, rem_from_p2 );
                p = &literals[ trail[ index + 1 ] ];
                c_idx = p->reason;
            }
            finish( it_p1, prem1.lits, rem_from_p1 );
            finish( it_p2, prem2.lits, rem_from_p2 );
            resolvent.shrink_to_fit();
            prem1.lits = std::move( resolvent );
            while ( -- conf_path > 0 );

            assert( p != nullptr );
            learnt.push_back( &p->get_dual() );
            return std::make_pair( std::move( learnt ), bt_level );
        }
    };
}