From 98a6f444f0fce0988b48fc377c64ec73bc606acc Mon Sep 17 00:00:00 2001
From: carlocamilloni <carlo.camilloni@gmail.com>
Date: Sun, 28 Oct 2018 19:26:42 +0100
Subject: [PATCH] saxs with gpu: gpu is run only by master rank

---
 regtest/isdb/rt-saxs-gpu/config |   1 +
 src/isdb/SAXS.cpp               | 137 +++++++++++++++++---------------
 2 files changed, 75 insertions(+), 63 deletions(-)

diff --git a/regtest/isdb/rt-saxs-gpu/config b/regtest/isdb/rt-saxs-gpu/config
index 8da77908e..d88b9287d 100644
--- a/regtest/isdb/rt-saxs-gpu/config
+++ b/regtest/isdb/rt-saxs-gpu/config
@@ -1,3 +1,4 @@
+mpiprocs=2
 type=driver
 arg="--plumed plumed.dat --timestep 0.005 --mf_pdb template.pdb"
 plumed_needs="arrayfire"
diff --git a/src/isdb/SAXS.cpp b/src/isdb/SAXS.cpp
index dc7ac6b66..7ae0992c1 100644
--- a/src/isdb/SAXS.cpp
+++ b/src/isdb/SAXS.cpp
@@ -358,6 +358,7 @@ SAXS::SAXS(const ActionOptions&ao):
     log<<plumed.cite("Fraser, MacRae, Suzuki, J. Appl. Crystallogr., 11, 693–694 (1978).");
     log<<plumed.cite("Brown, Fox, Maslen, O'Keefe, Willis, International Tables for Crystallography C, 554–595 (International Union of Crystallography, 2006).");
   }
+  if(bessel) log<<plumed.cite("Gumerov, Berlin, Fushman, Duraiswami, J. Comput. Chem. 33, 1981-1996 (2012).");
   log<< plumed.cite("Bonomi, Camilloni, Bioinformatics, 33, 3999 (2017)");
   log<<"\n";
 
@@ -383,77 +384,87 @@ void SAXS::calculate_gpu(vector<Vector> &deriv)
     rank   = 0;
   }
 
-  vector<float> posi;
-  posi.resize(3*size);
-  #pragma omp parallel for num_threads(OpenMP::getNumThreads())
-  for (unsigned i=0; i<size; i++) {
-    const Vector tmp = getPosition(i);
-    posi[3*i]   = static_cast<float>(tmp[0]);
-    posi[3*i+1] = static_cast<float>(tmp[1]);
-    posi[3*i+2] = static_cast<float>(tmp[2]);
-  }
+  std::vector<float> sum;
+  sum.resize(numq);
 
-  // create array a and b containing atomic coordinates
-  af::setDevice(deviceid);
-  // 3,size,1,1
-  af::array pos_a = af::array(3, size, &posi.front());
-  // size,3,1,1
-  pos_a = af::moddims(pos_a.T(), size, 1, 3);
-  // size,3,1,1
-  af::array pos_b = pos_a(af::span, af::span);
-  // size,1,3,1
-  pos_a = af::moddims(pos_a, size, 1, 3);
-  // 1,size,3,1
-  pos_b = af::moddims(pos_b, 1, size, 3);
-
-  // size,size,3,1
-  af::array xyz_dist = af::tile(pos_a, 1, size, 1) - af::tile(pos_b, size, 1, 1);
-  // size,size,1,1
-  af::array square = af::sum(xyz_dist*xyz_dist,2);
-  // size,size,1,1
-  af::array dist_sqrt = af::sqrt(square);
-  // replace the zero of square with one to avoid nan in the derivatives (the number does not matter becasue this are multiplied by zero)
-  af::replace(square,!(af::iszero(square)),1.);
-  // size,size,3,1
-  xyz_dist = xyz_dist / af::tile(square, 1, 1, 3);
-  // numq,1,1,1
-  af::array sum_device   = af::constant(0, numq, f32);
-  // numq,size,3,1
-  af::array deriv_device = af::constant(0, numq, size, 3, f32);
+  std::vector<float> dd;
+  dd.resize(size*3*numq);
 
-  for (unsigned k=0; k<numq; k++) {
-    // calculate FF matrix
-    // size,size,1,1
-    af::array FFdist_mod = af::tile(af::moddims(AFF_value(k, af::span), size, 1), 1, size)*af::tile(AFF_value(k, af::span), size, 1);
+  // on gpu only the master rank run the calculation
+  if(rank==0) {
+    vector<float> posi;
+    posi.resize(3*size);
+    #pragma omp parallel for num_threads(OpenMP::getNumThreads())
+    for (unsigned i=0; i<size; i++) {
+      const Vector tmp = getPosition(i);
+      posi[3*i]   = static_cast<float>(tmp[0]);
+      posi[3*i+1] = static_cast<float>(tmp[1]);
+      posi[3*i+2] = static_cast<float>(tmp[2]);
+    }
 
-    // get q
-    const float qvalue = static_cast<float>(q_list[k]);
-    // size,size,1,1
-    af::array dist_q = qvalue*dist_sqrt;
-    // size,size,1
-    af::array dist_sin = af::sin(dist_q)/dist_q;
-    af::replace(dist_sin,!(af::isNaN(dist_sin)),1.);
-    // 1,1,1,1
-    sum_device(k) = af::sum(af::flat(dist_sin)*af::flat(FFdist_mod));
+    // create array a and b containing atomic coordinates
+    af::setDevice(deviceid);
+    // 3,size,1,1
+    af::array pos_a = af::array(3, size, &posi.front());
+    // size,3,1,1
+    pos_a = af::moddims(pos_a.T(), size, 1, 3);
+    // size,3,1,1
+    af::array pos_b = pos_a(af::span, af::span);
+    // size,1,3,1
+    pos_a = af::moddims(pos_a, size, 1, 3);
+    // 1,size,3,1
+    pos_b = af::moddims(pos_b, 1, size, 3);
 
+    // size,size,3,1
+    af::array xyz_dist = af::tile(pos_a, 1, size, 1) - af::tile(pos_b, size, 1, 1);
     // size,size,1,1
-    af::array tmp = FFdist_mod*(dist_sin - af::cos(dist_q));
+    af::array square = af::sum(xyz_dist*xyz_dist,2);
+    // size,size,1,1
+    af::array dist_sqrt = af::sqrt(square);
+    // replace the zero of square with one to avoid nan in the derivatives (the number does not matter becasue this are multiplied by zero)
+    af::replace(square,!(af::iszero(square)),1.);
     // size,size,3,1
-    af::array dd_all = af::tile(tmp, 1, 1, 3)*xyz_dist;
-    // it should become 1,size,3
-    deriv_device(k, af::span, af::span) = af::sum(dd_all,0);
-  }
+    xyz_dist = xyz_dist / af::tile(square, 1, 1, 3);
+    // numq,1,1,1
+    af::array sum_device   = af::constant(0, numq, f32);
+    // numq,size,3,1
+    af::array deriv_device = af::constant(0, numq, size, 3, f32);
+
+    for (unsigned k=0; k<numq; k++) {
+      // calculate FF matrix
+      // size,size,1,1
+      af::array FFdist_mod = af::tile(af::moddims(AFF_value(k, af::span), size, 1), 1, size)*af::tile(AFF_value(k, af::span), size, 1);
+
+      // get q
+      const float qvalue = static_cast<float>(q_list[k]);
+      // size,size,1,1
+      af::array dist_q = qvalue*dist_sqrt;
+      // size,size,1
+      af::array dist_sin = af::sin(dist_q)/dist_q;
+      af::replace(dist_sin,!(af::isNaN(dist_sin)),1.);
+      // 1,1,1,1
+      sum_device(k) = af::sum(af::flat(dist_sin)*af::flat(FFdist_mod));
+
+      // size,size,1,1
+      af::array tmp = FFdist_mod*(dist_sin - af::cos(dist_q));
+      // size,size,3,1
+      af::array dd_all = af::tile(tmp, 1, 1, 3)*xyz_dist;
+      // it should become 1,size,3
+      deriv_device(k, af::span, af::span) = af::sum(dd_all,0);
+    }
 
-  // read out results
-  std::vector<float> sum;
-  sum.resize(numq);
-  sum_device.host(&sum.front());
+    // read out results
+    sum_device.host(&sum.front());
 
-  std::vector<float> dd;
-  dd.resize(size*3*numq);
-  deriv_device = af::reorder(deriv_device, 2, 1, 0);
-  deriv_device = af::flat(deriv_device);
-  deriv_device.host(&dd.front());
+    deriv_device = af::reorder(deriv_device, 2, 1, 0);
+    deriv_device = af::flat(deriv_device);
+    deriv_device.host(&dd.front());
+  }
+
+  if(!serial) {
+    comm.Bcast(deriv, 0);
+    comm.Bcast(sum, 0);
+  }
 
   for(unsigned k=0; k<numq; k++) {
     string num; Tools::convert(k,num);
-- 
GitLab