Skip to content
Snippets Groups Projects
Commit 77340130 authored by Giovanni Bussi's avatar Giovanni Bussi
Browse files

Optimized lepton switching function

It is not possible to use x2 as a variable. In this case,
sqrt of distance won't be computed.
parent 0e280f74
No related branches found
No related tags found
No related merge requests found
......@@ -224,6 +224,7 @@ void SwitchingFunction::set(const std::string & definition,std::string& errormsg
present=Tools::findKeyword(data,"MM");
if(present && !Tools::parse(data,"MM",mm)) errormsg="could not parse MM";
if(mm==0) mm=2*nn;
fastrational=(nn%2==0 && mm%2==0 && d0==0.0);
} else if(name=="SMAP") {
type=smap;
present=Tools::findKeyword(data,"A");
......@@ -262,18 +263,25 @@ void SwitchingFunction::set(const std::string & definition,std::string& errormsg
try {
lepton_ref[t]=&const_cast<lepton::CompiledExpression*>(&expression[t])->getVariableReference("x");
} catch(PLMD::lepton::Exception& exc) {
try {
lepton_ref[t]=&const_cast<lepton::CompiledExpression*>(&expression[t])->getVariableReference("x2");
leptonx2=true;
} catch(PLMD::lepton::Exception& exc) {
// this is necessary since in some cases lepton things a variable is not present even though it is present
// e.g. func=0*x
lepton_ref[t]=nullptr;
lepton_ref[t]=nullptr;
}
}
}
lepton::ParsedExpression ped=lepton::Parser::parse(func).differentiate("x").optimize(leptonConstants);
std::string arg="x";
if(leptonx2) arg="x2";
lepton::ParsedExpression ped=lepton::Parser::parse(func).differentiate(arg).optimize(leptonConstants);
expression_deriv.resize(OpenMP::getNumThreads());
for(auto & e : expression_deriv) e=ped.createCompiledExpression();
lepton_ref_deriv.resize(expression_deriv.size());
for(unsigned t=0; t<lepton_ref_deriv.size(); t++) {
try {
lepton_ref_deriv[t]=&const_cast<lepton::CompiledExpression*>(&expression_deriv[t])->getVariableReference("x");
lepton_ref_deriv[t]=&const_cast<lepton::CompiledExpression*>(&expression_deriv[t])->getVariableReference(arg);
} catch(PLMD::lepton::Exception& exc) {
// this is necessary since in some cases lepton things a variable is not present even though it is present
// e.g. func=3*x
......@@ -295,6 +303,8 @@ void SwitchingFunction::set(const std::string & definition,std::string& errormsg
stretch=1.0/(s0-sd);
shift=-sd*stretch;
}
plumed_assert(!(leptonx2 && d0!=0.0)) << "You cannot use lepton x2 optimization with d0!=0.0 (d0=" << d0 <<")\n"
<< "Please rewrite your function using x as a variable";
}
std::string SwitchingFunction::description() const {
......@@ -361,7 +371,7 @@ double SwitchingFunction::do_rational(double rdist,double&dfunc,int nn,int mm)co
}
double SwitchingFunction::calculateSqr(double distance2,double&dfunc)const {
if(type==rational && nn%2==0 && mm%2==0 && d0==0.0) {
if(fastrational) {
if(distance2>dmax_2) {
dfunc=0.0;
return 0.0;
......@@ -370,6 +380,24 @@ double SwitchingFunction::calculateSqr(double distance2,double&dfunc)const {
double result=do_rational(rdist_2,dfunc,nn/2,mm/2);
// chain rule:
dfunc*=2*invr0_2;
// stretch:
result=result*stretch+shift;
dfunc*=stretch;
return result;
} else if(leptonx2) {
if(distance2>dmax_2) {
dfunc=0.0;
return 0.0;
}
const unsigned t=OpenMP::getThreadNum();
const double rdist_2 = distance2*invr0_2;
plumed_assert(t<expression.size());
if(lepton_ref[t]) *lepton_ref[t]=rdist_2;
if(lepton_ref_deriv[t]) *lepton_ref_deriv[t]=rdist_2;
double result=expression[t].evaluate();
dfunc=expression_deriv[t].evaluate();
// chain rule:
dfunc*=2*invr0_2;
// stretch:
result=result*stretch+shift;
dfunc*=stretch;
......@@ -386,6 +414,11 @@ double SwitchingFunction::calculate(double distance,double&dfunc)const {
dfunc=0.0;
return 0.0;
}
// in this case, the lepton object stores only the calculateSqr function
// so we have to implement calculate in terms of calculateSqr
if(leptonx2) {
return calculateSqr(distance*distance,dfunc);
}
const double rdist = (distance-d0)*invr0;
double result;
......@@ -451,6 +484,8 @@ void SwitchingFunction::set(int nn,int mm,double r0,double d0) {
this->d0=d0;
this->dmax=d0+r0*pow(0.00001,1./(nn-mm));
this->dmax_2=this->dmax*this->dmax;
this->leptonx2=false;
this->fastrational=(nn%2==0 && mm%2==0 && d0==0.0);
double dummy;
double s0=calculate(0.0,dummy);
......
......@@ -83,6 +83,10 @@ class SwitchingFunction {
std::vector<lepton::CompiledExpression> expression_deriv;
std::vector<double*> lepton_ref;
std::vector<double*> lepton_ref_deriv;
/// Set to true for fast rational functions (depending on x**2 only)
bool fastrational=false;
/// Set to true if lepton only uses x2
bool leptonx2=false;
public:
static void registerKeywords( Keywords& keys );
/// Set a "rational" switching function.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment