diff --git a/src/dimred/SketchMapBase.cpp b/src/dimred/SketchMapBase.cpp index f4f2db7ec1af631f34eb0db5e551332abf0971d3..e1e5814635028cefca7896a0e3e1f700f7544115 100644 --- a/src/dimred/SketchMapBase.cpp +++ b/src/dimred/SketchMapBase.cpp @@ -75,11 +75,17 @@ smapbase(NULL) void SketchMapBase::calculateProjections( const Matrix<double>& targets, Matrix<double>& projections ){ if( dtargets.size()!=targets.nrows() ){ // These hold data so that we can do stress calculations - dtargets.resize( targets.nrows() ); ftargets.resize( targets.nrows() ); + dtargets.resize( targets.nrows() ); ftargets.resize( targets.nrows() ); pweights.resize( targets.nrows() ); // Matrices for storing input data transformed.resize( targets.nrows(), targets.ncols() ); - distances.resize( targets.nrows(), targets.ncols() ); + distances.resize( targets.nrows(), targets.ncols() ); } + + // Stores the weights in an array for faster access, as well as the normalization + normw=0; + for(unsigned i=0;i<targets.nrows() ;++i) { pweights[i] = getWeight(i); normw+=pweights[i]; } + normw*=normw; + // Transform the high dimensional distances double df; distances=0.; transformed=0.; for(unsigned i=1;i<distances.ncols();++i){ @@ -123,16 +129,12 @@ double SketchMapBase::calculateStress( const std::vector<double>& p, std::vector double SketchMapBase::calculateFullStress( const std::vector<double>& p, std::vector<double>& d ){ // Zero derivative and stress accumulators for(unsigned i=0;i<p.size();++i) d[i]=0.0; - double stress=0; std::vector<double> dtmp( p.size() ); - - // Compute normalization for weights - double normw = 0; - for(unsigned i=1;i<distances.nrows();++i){ - for(unsigned j=0;j<i;++j) normw += getWeight(i)*getWeight(j); - } + double stress=0; std::vector<double> dtmp( nlow ); for(unsigned i=1;i<distances.nrows();++i){ + double iweight = pweights[i]; for(unsigned j=0;j<i;++j){ + double jweight = pweights[j]; // Calculate distance in low dimensional space double dd=0; for(unsigned k=0;k<nlow;++k){ dtmp[k]=p[nlow*i+k] - p[nlow*j+k]; dd+=dtmp[k]*dtmp[k]; } @@ -144,16 +146,16 @@ double SketchMapBase::calculateFullStress( const std::vector<double>& p, std::ve double fdiff = fd - transformed(i,j);; // Calculate derivatives - double pref = 2.*getWeight(i)*getWeight(j) / (normw*dd); - for(unsigned k=0;k<p.size();++k){ - d[nlow*i+k] += pref*( (1-mixparam)*fdiff*df + mixparam*ddiff )*dtmp[k]; - d[nlow*j+k] -= pref*( (1-mixparam)*fdiff*df + mixparam*ddiff )*dtmp[k]; + double pref = 2.*iweight*jweight*( (1-mixparam)*fdiff*df + mixparam*ddiff ) / dd; + for(unsigned k=0;k<nlow;++k){ + double dterm=pref*dtmp[k]; d[nlow*i+k]+=dterm; d[nlow*j+k]-=dterm; } // Accumulate the total stress - stress += getWeight(i)*getWeight(j)*( (1-mixparam)*fdiff*fdiff + mixparam*ddiff*ddiff ) / normw; + stress += iweight*jweight*( (1-mixparam)*fdiff*fdiff + mixparam*ddiff*ddiff ); } } + stress /= normw; for (unsigned k=0; k < d.size(); ++k) d[k] /= normw; return stress; } diff --git a/src/dimred/SketchMapBase.h b/src/dimred/SketchMapBase.h index 374004b2756d3791f64ce4b40db15306121e9b4d..68f1f1bdb7429805524f275df395ba177a0d85db 100644 --- a/src/dimred/SketchMapBase.h +++ b/src/dimred/SketchMapBase.h @@ -39,7 +39,9 @@ private: SwitchingFunction lowdf, highdf; /// This is used within calculate stress to hold the target distances and the /// target values for the high dimensional switching function - std::vector<double> dtargets, ftargets; + std::vector<double> dtargets, ftargets, pweights; +/// Stress normalization (sum_ij w_i w_j) + double normw; protected: /// This holds the target distances and target transformed distances Matrix<double> distances, transformed; diff --git a/src/tools/SwitchingFunction.cpp b/src/tools/SwitchingFunction.cpp index 94d4299ab75ffed7ab79350f9345e76eb9e90152..cd95d213f4601be754a45ab4666df7a022ab310f 100644 --- a/src/tools/SwitchingFunction.cpp +++ b/src/tools/SwitchingFunction.cpp @@ -286,7 +286,7 @@ double SwitchingFunction::calculate(double distance,double&dfunc)const{ dfunc=0.0; }else{ if(type==smap){ - double sx=c*pow( rdist, a ); + double sx=c*Tools::fastpow( rdist, a ); result=pow( 1.0 + sx, d ); dfunc=-b*sx/rdist*result/(1.0+sx); } else if(type==rational){