From 685b065b5b3ee3f01994cab38aa98327582bb0ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Koren=C4=8Dik?= <xkorenc1@fi.muni.cz> Date: Thu, 30 Jan 2020 15:11:48 +0000 Subject: [PATCH] lart: Add pass that lowers ret type of functions from aggr value to pointer. --- lart/mcsema/lowerreturn.hpp | 378 ++++++++++++++++++++++++++++++++++++ 1 file changed, 378 insertions(+) create mode 100644 lart/mcsema/lowerreturn.hpp diff --git a/lart/mcsema/lowerreturn.hpp b/lart/mcsema/lowerreturn.hpp new file mode 100644 index 000000000..ae74a053c --- /dev/null +++ b/lart/mcsema/lowerreturn.hpp @@ -0,0 +1,378 @@ +/* + * (c) 2020 Lukáš Korenčik <xkorenc1@fi.muni.cz> + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ */
+
+
+#pragma once
+
+#include <cstdint>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+DIVINE_RELAX_WARNINGS
+#include <llvm/IR/Instructions.h>
+#include <llvm/IR/Module.h>
+DIVINE_UNRELAX_WARNINGS
+
+#include <brick-llvm>
+
+#include <lart/abstract/util.h>
+#include <lart/support/meta.h>
+#include <lart/support/query.h>
+#include <lart/support/util.h>
+
+namespace lart::mcsema
+{
+
+    // Pass that lowers the return type of lifted functions from an aggregate
+    // value to a pointer to that aggregate.
+    //
+    // run() performs these steps:
+    //   1. change_ret_type() obtains, through changeReturnType(), a twin of
+    //      every lifted function whose return type is a pointer to the
+    //      original return type;
+    //   2. synthetize_wrappers() emits helper functions that take the fields
+    //      of the returned struct, store them into memory obtained from
+    //      malloc and return the pointer;
+    //   3. replacer rewrites the returns inside the twins, redirects calls of
+    //      the originals to the twins and finally drops the originals.
+    //
+    // Limitations visible in this file: only structs whose fields are
+    // integers or pointers are supported (no nested structs), only the first
+    // index of insertvalue/extractvalue is honoured, and the allocated memory
+    // is never released.
+    struct lower_ret_agg : abstract::LLVMUtil< lower_ret_agg >
+    {
+        using values_t = std::vector< llvm::Value * >;
+
+        using functions_t = std::vector< llvm::Function * >;
+        // original function -> twin with the lowered return type
+        using functions_map_t = std::unordered_map< llvm::Function *, llvm::Function * >;
+        // function returning a pointer -> wrapper that packs its result
+        using agg_to_wrapper_t = std::unordered_map< llvm::Function *, llvm::Function * >;
+        using types_t = std::vector< llvm::Type * >;
+
+        // Name prefix of generated wrappers; next_name() appends a counter.
+        constexpr static const char *wrapper_prefix = "divine.ret.wrapper";
+
+        llvm::Module *_m;
+        const llvm::DataLayout &_dl;
+        llvm::LLVMContext &context;
+
+        // Tell whether `f` carries a "remill.function.type" metadata node whose
+        // single MDString operand mentions "lifted" or "helper.mcsema".
+        // A function without the annotation is not lifted; a malformed
+        // annotation is a fatal error.
+        bool is_lifted( llvm::Function &f )
+        {
+            auto md = f.getMetadata( "remill.function.type" );
+            if ( !md )
+                return false;
+            if ( md->getNumOperands() != 1 )
+                UNREACHABLE( "Lifted bc has invalid remill.function.type annotation");
+
+            auto md_str = llvm::dyn_cast< llvm::MDString >( md->getOperand( 0 ) );
+            if ( !md_str )
+                UNREACHABLE( "remill.function.type annotation has incorrect op type");
+
+            auto kind = md_str->getString();
+            return kind.contains( "lifted" ) || kind.contains( "helper.mcsema" );
+        }
+
+        // Snapshot of all functions of the module accepted by is_lifted().
+        // Returning a copy lets callers add functions to the module while
+        // they iterate over the result.
+        functions_t lifted_funcs()
+        {
+            functions_t out;
+            for ( auto &f : *_m )
+                if ( is_lifted( f ) )
+                    out.push_back( &f );
+            return out;
+        }
+
+        // Pointer to the current return type of `f`.
+        auto wrap_ret_t( llvm::Function *f ) { return ptr( f->getReturnType() ); }
+
+        // Call changeReturnType() for every lifted function and remember which
+        // function it returned for which original.
+        // NOTE(review): changeReturnType is not defined in this file; it is
+        // presumably expected to create a new function and keep the original
+        // in the module -- confirm against its definition.
+        functions_map_t change_ret_type()
+        {
+            functions_map_t changed;
+            for ( auto f : lifted_funcs() )
+                changed.insert( { f, changeReturnType( f, wrap_ret_t( f ) ) } );
+            return changed;
+        }
+
+
+        // Append the scalar leaves of `type` to `result`, depth first.
+        // Pointers and integers are leaves, structs are expanded recursively,
+        // anything else is a fatal error.
+        void flatten( llvm::Type *type, types_t &result )
+        {
+            if ( util::is_one_of_types< llvm::PointerType, llvm::IntegerType >( type ) )
+            {
+                result.push_back( type );
+                return;
+            }
+
+            if ( auto struct_t = llvm::dyn_cast< llvm::StructType >( type ) )
+            {
+                for ( auto e : struct_t->elements() )
+                    flatten( e, result );
+                return;
+            }
+            UNREACHABLE( "Cannot flatten this type" );
+        }
+
+        // Produce wrapper_prefix followed by a counter that starts at 1.
+        // The counter is a function-local static, hence shared by all
+        // instances of the pass.
+        std::string next_name()
+        {
+            static uint64_t counter = 0;
+            return std::string( wrapper_prefix ) + std::to_string( ++counter );
+        }
+
+        // FIXME: Currently we leak memory
+        // Placeholder for releasing memory obtained by allocate(); aborts when
+        // it is ever instantiated and called.
+        template< typename irb_t >
+        void free( llvm::Value *, irb_t & )
+        {
+            UNREACHABLE( "Not implemented" );
+        }
+
+        // Emit, through `irb`, a call to the module's `malloc` asking for the
+        // alloc size of `type` and return the result cast to `type *`.
+        // It is a fatal error when the module does not contain `malloc`.
+        template< typename irb_t >
+        llvm::Value *allocate( llvm::Type * type, irb_t &irb )
+        {
+            // Currently we use malloc
+            auto malloc_f = _m->getFunction( "malloc" );
+            if ( !malloc_f )
+                UNREACHABLE( "Could not find malloc while lowering struct" );
+
+            auto memory = irb.CreateCall( malloc_f, i64( _dl.getTypeAllocSize( type ) ) );
+            return irb.CreateBitCast( memory, ptr( type ) );
+        }
+
+        // Create a function that takes one parameter per field of the struct
+        // `type`, stores the parameters into freshly allocated memory and
+        // returns a pointer to it.
+        llvm::Function *synthetize_wrapper( llvm::Type * type )
+        {
+            auto struct_t = llvm::dyn_cast< llvm::StructType >( type );
+            ASSERT( struct_t );
+
+            types_t flattened;
+            flatten( struct_t, flattened );
+            // The stores below pair parameter i with field i, which is only
+            // right when flattening did not expand any nested struct.
+            ASSERT( flattened.size() == struct_t->getNumElements() );
+
+            auto wrapper_t = llvm::FunctionType::get( ptr( type ), flattened, false );
+            auto wrapper_fc = _m->getOrInsertFunction( next_name(), wrapper_t );
+            // cast<> instead of dyn_cast<>: the result is dereferenced
+            // unconditionally, so anything but a Function has to be caught here.
+            auto wrapper_f = llvm::cast< llvm::Function >( wrapper_fc );
+
+            auto entry = llvm::BasicBlock::Create( context, "entry" , wrapper_f );
+
+            llvm::IRBuilder<> irb( entry );
+            auto memory = allocate( type, irb );
+
+            for ( auto i = 0U; i < struct_t->getNumElements(); ++i )
+            {
+                auto gep = irb.CreateInBoundsGEP( memory, { i32( 0 ), i32( i ) } );
+                irb.CreateStore( argument( wrapper_f, i ), gep );
+            }
+            irb.CreateRet( memory );
+            return wrapper_f;
+        }
+
+        // Map every lifted function that returns a pointer to a wrapper for
+        // the pointed-to type. Functions returning the same type share one
+        // wrapper.
+        agg_to_wrapper_t synthetize_wrappers()
+        {
+            agg_to_wrapper_t wrappers;
+            std::unordered_map< llvm::Type *, llvm::Function * > by_type;
+
+            for ( auto f : lifted_funcs() )
+            {
+                // Lifted functions that do not return a pointer are skipped;
+                // only those with a pointer return type get a wrapper.
+                auto ptr_t = llvm::dyn_cast< llvm::PointerType >( f->getReturnType() );
+                if ( !ptr_t )
+                    continue;
+
+                auto struct_type = ptr_t->getElementType();
+                auto &shared = by_type[ struct_type ];
+                if ( !shared )
+                    shared = synthetize_wrapper( struct_type );
+                wrappers[ f ] = shared;
+            }
+            return wrappers;
+        }
+
+
+        // Rewrites the module once the twins and the wrappers exist, see
+        // replace() for the order of the steps.
+        struct replacer
+        {
+
+            // Rewrite the returns of every twin.
+            void replace_pack()
+            {
+                for ( auto &twin : functions )
+                    replace_pack( *twin.second );
+            }
+
+            // Visit `val` and then, repeatedly, the aggregate operand of the
+            // visited instruction for as long as it is an `inst_t`; `yield` is
+            // called for each of them. Returns false when `val` is null or is
+            // not an `inst_t`.
+            template< typename inst_t, typename yield_t >
+            bool walk( llvm::Value *val, yield_t yield )
+            {
+                if ( !val ) return false;
+                auto inst = llvm::dyn_cast< inst_t >( val );
+                if ( !inst ) return false;
+
+                while ( inst )
+                {
+                    yield( inst );
+                    inst = llvm::dyn_cast< inst_t >( inst->getAggregateOperand() );
+                }
+                return true;
+            }
+
+            // One null slot per field of the struct that `f` returns a
+            // pointer to.
+            values_t get( llvm::Function &f )
+            {
+                auto struct_t = get_original_type( &f );
+                return values_t( struct_t->getNumElements(), nullptr );
+            }
+
+            // Struct type behind the return type of the function that
+            // contains `inst`.
+            llvm::StructType *get_original_type( llvm::Instruction *inst )
+            {
+                return get_original_type( inst->getParent()->getParent() );
+            }
+
+            // Struct type `f` returns a pointer to; `f` must return a pointer.
+            // The result is null when the pointee is not a struct.
+            llvm::StructType *get_original_type( llvm::Function *f )
+            {
+                auto ptr_t = llvm::dyn_cast< llvm::PointerType >( f->getReturnType() );
+                return llvm::dyn_cast< llvm::StructType >( ptr_t->getElementType() );
+            }
+
+            // Replace `ret` by a call of the wrapper registered for the
+            // enclosing function, fed with `args` (one value per struct
+            // field), followed by a return of the pointer the wrapper yields.
+            void replace_returns( llvm::Instruction *ret, values_t args )
+            {
+                auto ret_t = get_original_type( ret );
+
+                auto found = wrappers.find( ret->getParent()->getParent() );
+                if ( found == wrappers.end() )
+                    UNREACHABLE( "Missing return wrapper for function with lowered return type" );
+
+                ASSERT( ret_t->getNumElements() == args.size());
+
+                llvm::IRBuilder<> irb( ret );
+                auto memory = irb.CreateCall( found->second, args );
+                irb.CreateRet( memory );
+
+                ret->eraseFromParent();
+            }
+
+            // Rewrite every return of `f`: gather the field values from the
+            // insertvalue chain that builds the returned aggregate and pass
+            // them to the wrapper instead.
+            void replace_pack( llvm::Function &f )
+            {
+                auto rets = query::query( f )
+                    .flatten()
+                    .filter( query::llvmdyncast< llvm::ReturnInst > )
+                    .map( query::refToPtr )
+                    .freeze();
+
+                for ( auto ret : rets )
+                {
+                    // All state is local to one return, nothing is reused
+                    // after having been moved from.
+                    values_t args = get( f );
+                    std::vector< llvm::InsertValueInst * > chain;
+                    llvm::Value *base = ret->getOperand( 0 );
+
+                    auto collect = [ & ]( llvm::InsertValueInst *insertvalue )
+                    {
+                        auto idx = *insertvalue->idx_begin();
+                        // The walk runs from the newest insertvalue to the
+                        // oldest one, so the first value seen for a field is
+                        // the one that is actually returned.
+                        if ( !args[ idx ] )
+                            args[ idx ] = insertvalue->getInsertedValueOperand();
+                        chain.push_back( insertvalue );
+                        base = insertvalue->getAggregateOperand();
+                    };
+                    walk< llvm::InsertValueInst >( ret->getOperand( 0 ), collect );
+
+                    // Fields the chain never wrote keep the value they have in
+                    // the aggregate the chain started from (or in the returned
+                    // value itself when there is no chain at all).
+                    llvm::IRBuilder<> irb( ret );
+                    for ( unsigned i = 0; i < args.size(); ++i )
+                        if ( !args[ i ] )
+                            args[ i ] = irb.CreateExtractValue( base, { i } );
+
+                    replace_returns( ret, std::move( args ) );
+
+                    // Erase the insertvalues that became dead, newest first so
+                    // that the removal cascades down the chain. Instructions
+                    // still used elsewhere (e.g. by another return) stay.
+                    for ( auto insertvalue : chain )
+                        if ( insertvalue->use_empty() )
+                            insertvalue->eraseFromParent();
+                }
+            }
+
+            // Remove the given instructions; whatever still uses them gets
+            // undef instead.
+            template< typename V >
+            void erase( std::vector< V > erasable )
+            {
+                for ( auto v : erasable )
+                {
+                    auto inst = llvm::cast< llvm::Instruction >( v );
+                    inst->replaceAllUsesWith( llvm::UndefValue::get( inst->getType() ) );
+                    inst->eraseFromParent();
+                }
+            }
+
+            // Redirect every direct call of an original function to its twin.
+            // Any other use of an original is reported on stderr and left
+            // alone.
+            void replace_calls()
+            {
+                for ( auto [ original, func ] : functions )
+                {
+                    // Take a snapshot first: fix() adds instructions, which
+                    // must not happen while the user list is being iterated.
+                    std::vector< llvm::CallInst * > calls;
+                    for ( auto user : original->users() )
+                    {
+                        auto call = llvm::dyn_cast< llvm::CallInst >( user );
+                        // A call that merely passes `original` as an argument
+                        // is not a call of it.
+                        if ( call && call->getCalledFunction() == original )
+                        {
+                            calls.push_back( call );
+                            continue;
+                        }
+                        std::cerr << "\nUnsupported fix" << std::endl;
+                        user->print(llvm::errs());
+                    }
+
+                    for ( auto call : calls )
+                        fix( call, func );
+                    erase( std::move( calls ) );
+                }
+            }
+
+            // Emit, right before the old call site, a call of `new_f` with the
+            // same arguments.
+            llvm::CallInst *rewire( llvm::CallSite old_cs, llvm::Function *new_f )
+            {
+                llvm::IRBuilder<> irb( old_cs.getInstruction() );
+                return irb.CreateCall(
+                        new_f,
+                        values_t( old_cs.arg_begin(), old_cs.arg_end() ) );
+            }
+
+            // Replace `old_call` by a call of `new_f` and turn every
+            // extractvalue that reads the old result into a GEP + load on the
+            // pointer the new call returns. `old_call` itself is left in
+            // place for the caller to erase.
+            void fix( llvm::CallInst *old_call, llvm::Function *new_f )
+            {
+                auto n_call = rewire( { old_call }, new_f );
+
+                // Collect first and lower afterwards, so that each extractvalue
+                // is lowered exactly once.
+                std::vector< llvm::ExtractValueInst * > extracts;
+                auto collect = [ & ]( llvm::ExtractValueInst *extractvalue )
+                {
+                    extracts.push_back( extractvalue );
+                };
+
+                for ( auto user : old_call->users() )
+                    walk< llvm::ExtractValueInst >( user, collect );
+
+                for ( auto extractvalue : extracts )
+                {
+                    uint64_t idx = *extractvalue->idx_begin();
+                    llvm::IRBuilder<> irb( extractvalue );
+                    auto gep = irb.CreateGEP( n_call,
+                                              { pass.i32( 0 ), pass.i32( idx ) } );
+                    auto load = irb.CreateLoad( gep );
+                    extractvalue->replaceAllUsesWith( load );
+                }
+                erase( std::move( extracts ) );
+            }
+
+            // Drop the originals and let the twins take over their names.
+            void replace_funcs()
+            {
+                for ( auto[ original, func ] : functions )
+                {
+                    // The name has to be moved over before the original is
+                    // destroyed: getName() only hands out a view of storage
+                    // owned by the function.
+                    func->takeName( original );
+                    original->eraseFromParent();
+                }
+            }
+
+            // Run all rewriting steps; the order matters, the originals may
+            // only go away once nothing calls them any more.
+            void replace()
+            {
+                replace_pack();
+                replace_calls();
+                replace_funcs();
+            }
+
+            replacer( functions_map_t &functions_,
+                      const agg_to_wrapper_t &wrappers_,
+                      lower_ret_agg &pass_ )
+                : functions( functions_ ),
+                  wrappers( wrappers_ ),
+                  pass( pass_ )
+            {}
+
+            functions_map_t &functions;        // original -> twin
+            const agg_to_wrapper_t &wrappers;  // twin -> wrapper
+            lower_ret_agg &pass;               // provides the i32() helper
+        };
+
+        lower_ret_agg( llvm::Module &m ) : _m( &m ),
+                                           _dl( m.getDataLayout() ),
+                                           context( m.getContext() ) {}
+
+        // Entry point of the pass: create the twins and the wrappers, rewrite
+        // the module and verify the result.
+        void run()
+        {
+            auto twins = change_ret_type();
+            auto wrappers = synthetize_wrappers();
+            replacer( twins, wrappers, *this ).replace();
+            brick::llvm::verifyModule( _m );
+        }
+
+    };
+
+}
-- 
GitLab