Skip to content
Snippets Groups Projects
Commit ebef6841 authored by Omar Valsson's avatar Omar Valsson
Browse files

lepton performance fixes for BF_CUSTOM

parent 134d25d9
No related branches found
No related tags found
No related merge requests found
......@@ -106,8 +106,12 @@ class BF_Custom : public BasisFunctions {
private:
lepton::CompiledExpression transf_value_expression_;
lepton::CompiledExpression transf_deriv_expression_;
double* transf_value_lepton_ref_;
double* transf_deriv_lepton_ref_;
std::vector<lepton::CompiledExpression> bf_values_expressions_;
std::vector<lepton::CompiledExpression> bf_derivs_expressions_;
std::vector<double*> bf_values_lepton_ref_;
std::vector<double*> bf_derivs_lepton_ref_;
std::string variable_str_;
std::string transf_variable_str_;
bool do_transf_;
......@@ -134,8 +138,12 @@ void BF_Custom::registerKeywords(Keywords& keys) {
BF_Custom::BF_Custom(const ActionOptions&ao):
PLUMED_VES_BASISFUNCTIONS_INIT(ao),
transf_value_lepton_ref_(nullptr),
transf_deriv_lepton_ref_(nullptr),
bf_values_expressions_(0),
bf_derivs_expressions_(0),
bf_values_lepton_ref_(0,nullptr),
bf_derivs_lepton_ref_(0,nullptr),
variable_str_("x"),
transf_variable_str_("t"),
do_transf_(false),
......@@ -172,6 +180,8 @@ BF_Custom::BF_Custom(const ActionOptions&ao):
//
bf_values_expressions_.resize(getNumberOfBasisFunctions());
bf_derivs_expressions_.resize(getNumberOfBasisFunctions());
bf_values_lepton_ref_.resize(getNumberOfBasisFunctions());
bf_derivs_lepton_ref_.resize(getNumberOfBasisFunctions());
//
for(unsigned int i=1; i<getNumberOfBasisFunctions(); i++) {
std::string is; Tools::convert(i,is);
......@@ -206,6 +216,14 @@ BF_Custom::BF_Custom(const ActionOptions&ao):
plumed_merror("There was some problem in parsing the derivative of the function "+bf_str[i]+" given in FUNC"+is + " with lepton");
}
try {
bf_values_lepton_ref_[i] = &bf_values_expressions_[i].getVariableReference(variable_str_);
} catch(PLMD::lepton::Exception& exc) {}
try {
bf_derivs_lepton_ref_[i] = &bf_derivs_expressions_[i].getVariableReference(variable_str_);
} catch(PLMD::lepton::Exception& exc) {}
}
std::string transf_value_parsed;
......@@ -254,6 +272,15 @@ BF_Custom::BF_Custom(const ActionOptions&ao):
catch(PLMD::lepton::Exception& exc) {
plumed_merror("There was some problem in parsing the derivative of the function "+transf_str+" given in TRANSFORM with lepton");
}
try {
transf_value_lepton_ref_ = &transf_value_expression_.getVariableReference(transf_variable_str_);
} catch(PLMD::lepton::Exception& exc) {}
try {
transf_deriv_lepton_ref_ = &transf_deriv_expression_.getVariableReference(transf_variable_str_);
} catch(PLMD::lepton::Exception& exc) {}
}
//
log.printf(" Using the following functions [lepton parsed function and derivative]:\n");
......@@ -290,19 +317,12 @@ void BF_Custom::getAllValues(const double arg, double& argT, bool& inside_range,
double transf_derivf=1.0;
//
if(do_transf_) {
// has to copy as the function is const
lepton::CompiledExpression ce_value = transf_value_expression_;
try {
ce_value.getVariableReference(transf_variable_str_) = argT;
} catch(PLMD::lepton::Exception& exc) {}
lepton::CompiledExpression ce_deriv = transf_deriv_expression_;
try {
ce_deriv.getVariableReference(transf_variable_str_) = argT;
} catch(PLMD::lepton::Exception& exc) {}
if(transf_value_lepton_ref_) {*transf_value_lepton_ref_ = argT;}
if(transf_deriv_lepton_ref_) {*transf_deriv_lepton_ref_ = argT;}
argT = ce_value.evaluate();
transf_derivf = ce_deriv.evaluate();
argT = transf_value_expression_.evaluate();
transf_derivf = transf_deriv_expression_.evaluate();
if(check_nan_inf_ && (std::isnan(argT) || std::isinf(argT)) ) {
std::string vs; Tools::convert(argT,vs);
......@@ -318,17 +338,13 @@ void BF_Custom::getAllValues(const double arg, double& argT, bool& inside_range,
values[0]=1.0;
derivs[0]=0.0;
for(unsigned int i=1; i < getNumberOfBasisFunctions(); i++) {
lepton::CompiledExpression ce_value = bf_values_expressions_[i];
try {
ce_value.getVariableReference(variable_str_) = argT;
} catch(PLMD::lepton::Exception& exc) {}
values[i] = ce_value.evaluate();
lepton::CompiledExpression ce_deriv = bf_derivs_expressions_[i];
try {
ce_deriv.getVariableReference(variable_str_) = argT;
} catch(PLMD::lepton::Exception& exc) {}
derivs[i] = ce_deriv.evaluate();
if(bf_values_lepton_ref_[i]) {*bf_values_lepton_ref_[i] = argT;}
if(bf_derivs_lepton_ref_[i]) {*bf_derivs_lepton_ref_[i] = argT;}
values[i] = bf_values_expressions_[i].evaluate();
derivs[i] = bf_derivs_expressions_[i].evaluate();
if(do_transf_) {derivs[i]*=transf_derivf;}
// NaN checks
if(check_nan_inf_ && (std::isnan(values[i]) || std::isinf(values[i])) ) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment