package cz.fidentis.analyst.batch;

import cz.fidentis.analyst.Logger;
import cz.fidentis.analyst.core.ProgressDialog;
import cz.fidentis.analyst.face.HumanFace;
import cz.fidentis.analyst.face.HumanFaceFactory;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;

/**
 * A task that computes similarity of a set of faces by computing
 * the distance of faces to a template face (one of the input faces) and then
 * combining these values to get mutual similarity for all pairs.
 * <p>
 * The computation is accelerated by using multiple CPU cores concurrently.
 * The exact computation parameters are taken from the {@code BatchPanel}.
 * </p>
 * 
 * @author Radek Oslejsek
 */
public class ApproxHausdorffDistTask extends SimilarityTask {
    
    private final Stopwatch totalTime = new Stopwatch("Total computation time:\t\t");
    private final Stopwatch hdComputationTime = new Stopwatch("Hausdorff distance preparation time:\t");
    private final Stopwatch loadTime = new Stopwatch("Disk access time:\t\t");
    private final Stopwatch kdTreeConstructionTime = new Stopwatch("KD trees construction time:\t\t");
    private final Stopwatch finalHdComputationTime = new Stopwatch("Hausdorff distance finalization time:\t");
    
    /** Index (into the list of face paths) of the face used as the template. */
    private final int templateFaceIndex;
    
    /**
     * Constructor.
     * 
     * @param progressDialog A window that show the progress of the computation. Must not be {@code null}
     * @param controlPanel A control panel with computation parameters. Must not be {@code null}
     * @param templateFaceIndex Index of the template face in the list of face paths 
     *        provided by the control panel
     */
    public ApproxHausdorffDistTask(ProgressDialog progressDialog, BatchPanel controlPanel, int templateFaceIndex) {
        super(progressDialog, controlPanel);
        this.templateFaceIndex = templateFaceIndex;
    }
    
    /**
     * Measures the relative Hausdorff distance from the template face to every 
     * other face, caches the per-vertex distances, and then combines them into 
     * pairwise similarities. Time statistics are printed at the end.
     * 
     * @return always {@code null}
     * @throws Exception if the computation fails
     */
    @Override
    protected Void doInBackground() throws Exception {
        HumanFaceFactory factory = getControlPanel().getHumanFaceFactory();
        List<Path> faces = getControlPanel().getFacePaths();
        
        totalTime.start();
        
        factory.setReuseDumpFile(true); // it's safe because no changes are made to models 
        factory.setStrategy(HumanFaceFactory.Strategy.MRU); // keep first X faces in the memory
        
        // We don't need to reload the template face periodically for two reasons:
        //   - It is never dumped from memory to disk because we use MRU
        //   - Even if dumped, the face stays in memory as long as we hold the pointer to it
        loadTime.start();
        String templateFaceId = factory.loadFace(faces.get(templateFaceIndex).toFile());
        HumanFace templateFace = factory.getFace(templateFaceId);
        loadTime.stop();
        
        // Cache of distances of individual vertices
        //   distCache.get(i) = i-th face
        //   distCache.get(i).get(j) = distance of j-th vertex of i-th face
        List<List<Double>> distCache = new ArrayList<>();
        
        for (int i = 0; i < faces.size(); i++) {
            
            if (isCancelled()) { // the user canceled the process
                return null;
            }
            
            if (i != templateFaceIndex) {
                
                loadTime.start();
                String faceId = factory.loadFace(faces.get(i).toFile());
                HumanFace face = factory.getFace(faceId);
                loadTime.stop();
                
                kdTreeConstructionTime.start();
                face.computeKdTree(false);
                kdTreeConstructionTime.stop();

                hdComputationTime.start();
                HausdorffDistance hd = new HausdorffDistance(
                        face.getKdTree(),
                        HausdorffDistance.Strategy.POINT_TO_POINT,
                        true, // relative
                        true, // parallel
                        true  // crop
                );
                templateFace.getMeshModel().compute(hd, true);
                
                // Store relative distances of individual vertices to the cache                
                distCache.add(hd.getDistances()
                        .values()
                        .stream()
                        .flatMap(List::stream)
                        .collect(Collectors.toList()));
                face.removeKdTree(); // TO BE TESTED
                
                hdComputationTime.stop();
            } else {
                // The template face has zero distance to itself at every vertex
                distCache.add(DoubleStream.generate(() -> 0.0d)
                        .limit(templateFace.getMeshModel().getNumVertices())
                        .boxed()
                        .collect(Collectors.toList())
                );
            }
            
            // update progress bar
            int progress = (int) Math.round(100.0 * (i+1) / faces.size());
            getProgressDialog().setValue(progress);
        }
        
        finalHdComputationTime.start();
        finalizeHdConcurrently(distCache, templateFace.getMeshModel().getNumVertices());
        finalHdComputationTime.stop();
        
        totalTime.stop();

        printTimeStats();

        return null;
    }
    
    /**
     * Prints the measured times of the individual computation phases to the log.
     */
    protected void printTimeStats() {
        Logger.print(hdComputationTime.toString());
        Logger.print(loadTime.toString());
        Logger.print(finalHdComputationTime.toString());
        Logger.print(kdTreeConstructionTime.toString());
        Logger.print(totalTime.toString());
    }
    
    /**
     * Sequentially combines the cached per-vertex distances into pairwise 
     * similarities and stores them symmetrically for every pair of faces.
     * 
     * @param distCache Per-vertex distances; {@code distCache.get(i).get(k)} is 
     *        the distance of the k-th vertex for the i-th face
     * @param numVertices Number of vertices to be processed for every face
     * @deprecated Use {@link #finalizeHdConcurrently(List, long)} instead.
     */
    @Deprecated
    protected void finalizeHD(List<List<Double>> distCache, long numVertices) {
        for (int i = 0; i < distCache.size(); i++) {
            for (int j = i; j < distCache.size(); j++) {
                assert(distCache.get(i) == null || distCache.get(j) == null 
                        || distCache.get(i).size() == distCache.get(j).size());
                double avg = averageAbsDifference(distCache.get(i), distCache.get(j), numVertices);
                setDistSimilarity(i, j, avg);
                setDistSimilarity(j, i, avg);
            }
            
            // update progress bar
            int progress = (int) Math.round(100.0 * (i+1) / distCache.size());
            getProgressDialog().setValue(progress);
        }
    }
    
    /**
     * Combines the cached per-vertex distances into pairwise similarities, 
     * processing every pair of faces as a separate task in a thread pool sized 
     * to the number of available processors. The method blocks until all tasks 
     * are finished or the calling thread is interrupted; in the latter case the 
     * remaining tasks are cancelled and the interrupt flag is restored.
     * 
     * @param distCache Per-vertex distances; {@code distCache.get(i).get(k)} is 
     *        the distance of the k-th vertex for the i-th face
     * @param numVertices Not used by this method (the size of the cached lists 
     *        is used instead); kept for signature compatibility with 
     *        {@link #finalizeHD(List, long)}
     */
    protected void finalizeHdConcurrently(List<List<Double>> distCache, long numVertices) {
        ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        final List<Future<Void>> results = new ArrayList<>();
        
        try {
            for (int i = 0; i < distCache.size(); i++) {
                for (int j = i; j < distCache.size(); j++) {
                    assert(distCache.get(i).size() == distCache.get(j).size()); // mismatch should never happen
                    
                    final int fi = i;
                    final int fj = j;
                    
                    // Compute average HD for the pair of faces in a separate thread:
                    Callable<Void> task = () -> {
                        List<Double> first = distCache.get(fi);
                        double avg = averageAbsDifference(first, distCache.get(fj), first.size());
                        setDistSimilarity(fi, fj, avg);
                        setDistSimilarity(fj, fi, avg);
                        return null;
                    };
                    results.add(executor.submit(task));
                }
            }
            
            // Block on the futures (no busy-waiting) and report the real progress
            int finished = 0;
            for (Future<Void> res: results) {
                try {
                    res.get();
                } catch (final ExecutionException ex) {
                    // A failure of one pair must not prevent waiting for the others
                    java.util.logging.Logger.getLogger(ApproxHausdorffDistTask.class.getName()).log(Level.SEVERE, null, ex);
                }
                finished++;
                int progress = (int) Math.round(100.0 * finished / results.size());
                getProgressDialog().setValue(progress);
            }
        } catch (final InterruptedException ex) {
            Thread.currentThread().interrupt(); // preserve the interruption for the caller
            java.util.logging.Logger.getLogger(ApproxHausdorffDistTask.class.getName()).log(Level.SEVERE, null, ex);
        } finally {
            // Releases the pool threads; cancels pending tasks if we were interrupted or failed
            executor.shutdownNow();
        }
    }
    
    /**
     * Computes the average absolute difference of the first {@code count} 
     * corresponding values of two lists. Pairs in which either value is not 
     * finite are skipped.
     * 
     * @param first First list of values
     * @param second Second list of values
     * @param count Number of leading elements to be processed
     * @return The average absolute difference, or {@code NaN} if no pair of 
     *         finite values exists
     */
    private static double averageAbsDifference(List<Double> first, List<Double> second, long count) {
        double sum = 0.0;
        int counter = 0;
        for (int k = 0; k < count; k++) {
            double d1 = first.get(k);
            double d2 = second.get(k);
            if (Double.isFinite(d1) && Double.isFinite(d2)) {
                sum += Math.abs(d1 - d2);
                counter++;
            }
        }
        return sum / counter;
    }
    
}
