Commit 167f2c9c authored by Vladimír Uhlík's avatar Vladimír Uhlík
Browse files

src: Fix broken restart scheme and optimize ldb computation a bit.

parent f27cdf64
Loading
Loading
Loading
Loading
+31 −14
Original line number Diff line number Diff line
#pragma once

#include "types.hpp"
#include <unordered_set>
#include <cstring>

using zilla::SatState;
using zilla::clause;
using zilla::formula;
using zilla::idx_t;
using zilla::max_idx;
using zilla::SatState;
using zilla::variable;

struct solver
@@ -13,8 +15,14 @@ struct solver
    formula &f;
    std::uint64_t short_ema = 0;
    std::uint64_t long_ema  = 0;
    std::vector< idx_t > levels;
    idx_t level_stamp;

    solver( formula &_f ) : f( _f ) {}
    solver( formula &_f ) : f( _f ), level_stamp( 0 )
    {
        levels.resize( f.nvariables );
        std::memset( levels.data(), 0, levels.size() * sizeof( idx_t ) );
    }

    template< std::uint64_t scale, std::uint64_t mult >
    void update_ema( std::uint64_t &ema, std::uint64_t ldb )
@@ -22,14 +30,22 @@ struct solver
        ema = ( ldb << ( 32ull - scale ) ) + ( ( mult * ema ) >> scale );
    }

    std::uint64_t compute_ldb( const clause &c )
    std::uint64_t compute_ldb( const auto &v )
    {
        std::unordered_set< zilla::idx_t > unique;
        std::uint64_t result = 0;
        ++ level_stamp;

        for ( auto *l : c.lits )
            unique.insert( l->level );
        for ( const auto *l : v )
        {
            auto lvl = l->get_level();
            if ( lvl != max_idx && levels[ lvl ] != level_stamp )
            {
                levels[ lvl ] = level_stamp;
                ++ result;
            }
        }

        return unique.size();
        return result;
    }

    SatState solve_cdcl()
@@ -51,22 +67,23 @@ struct solver
                if ( result != SatState::CONFLICT )
                    break;

                auto [ learnt, backtracklevel ] = f.conflict_analysis( c_idx );
                if ( backtracklevel == zilla::max_level )
                auto [ learnt_lits, backtracklevel ] = f.conflict_analysis( c_idx );
                if ( backtracklevel == max_idx )
                    return SatState::UNSAT;

                f.learn( learnt );
                f.backtrack( backtracklevel );

                auto ldb = compute_ldb( learnt );
                auto ldb = compute_ldb( learnt_lits );
                update_ema<  5,   31 >( short_ema, ldb );
                update_ema< 12, 4095 >(  long_ema, ldb );

                f.learn( learnt_lits );

                if ( ++ conflicts >= 50 && short_ema > 1.15 * long_ema )
                {
                    f.restart();
                    conflicts = 0;
                }
                else
                    f.backtrack( backtracklevel );
            }
        }
        return SatState::SAT;