package cz.fidentis.analyst.tests;

import cz.fidentis.analyst.visitors.kdtree.AvgFaceConstructor;
import cz.fidentis.analyst.face.HumanFace;
import cz.fidentis.analyst.face.HumanFaceFactory;
import cz.fidentis.analyst.icp.IcpTransformer;
import cz.fidentis.analyst.icp.NoUndersampling;
import cz.fidentis.analyst.mesh.io.MeshObjExporter;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance.Strategy;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A class for testing the efficiency of batch processing algorithms.
 * It works in the same way as {@link BatchSimilarityGroundTruth} with the following changes:
 * <ul>
 * <li>All faces are registered (ICP) to the very first face of the collection instead of being registered pairwise.
 *   This allows ICP to be computed only once per face, at the cost of a possible loss of precision.
 *   Moreover, {@code HumanFaceFactory} is used to quickly load faces and swap them in and out of memory
 *   when they are needed multiple times.</li>
 * </ul>
 * Stats for 100 faces WITH CROP:
 * <pre>
 * Time of AVG face computation time: 00:00:19,529
 * ICP computation time: 00:05:19,096
 * HD computation time: 03:17:29,446
 * Total computation time: 03:32:30,671
 * </pre>
 * Stats for 100 faces WITHOUT CROP:
 * <pre>
 * Time of AVG face computation time: 00:00:19,386
 * ICP computation time: 00:05:39,318
 * HD computation time: 03:11:49,226
 * Total computation time: 03:25:50,957
 * </pre>
 * 
 * @author Radek Oslejsek
 */
public class BatchSimilarityGroundTruthOpt {
    
    private static final String DATA_DIR = "../../analyst-data-antropologie/_ECA";
    private static final String OUTPUT_FILE = "../../SIMILARITY_GROUND_TRUTH_OPT.csv";
    private static final String TEMPLATE_FACE_PATH = "../../analyst-data-antropologie/template_face.obj";
    private static final int MAX_SAMPLES = 100;
    private static final boolean CROP_HD = false;
    
    /**
     * Main method.
     * <p>
     * Registers every face to the very first face (ICP), computes the average Hausdorff
     * distance of every pair of faces in both directions, builds an average face from the
     * registered faces, exports it to {@link #TEMPLATE_FACE_PATH}, writes the distances to
     * {@link #OUTPUT_FILE} and prints the measured computation times.
     * </p>
     * 
     * @param args Input arguments (not used)
     * @throws IOException on IO error
     * @throws ClassNotFoundException declared for the face loading/swapping machinery
     *         (presumably when a dumped face is read back — confirm in {@code HumanFaceFactory})
     * @throws Exception if processing of the faces fails
     */
    public static void main(String[] args) throws IOException, ClassNotFoundException, Exception {
        List<Path> faces = listFaces();
        if (faces.isEmpty()) {
            System.err.println("No .obj faces found in " + DATA_DIR);
            return;
        }
        
        double[][] distances = new double[faces.size()][faces.size()];
        String[] names = new String[faces.size()];
        
        long totalTime = System.currentTimeMillis();
        long avgFaceComputationTime = 0;
        long icpComputationTime = 0;
        long hdComputationTime = 0;
        
        // The average face is built on top of a separately loaded copy of the very first face
        AvgFaceConstructor avgFaceConstructor = new AvgFaceConstructor(new HumanFace(faces.get(0).toFile()).getMeshModel());
        
        HumanFaceFactory factory = new HumanFaceFactory();
        factory.setReuseDumpFile(true);
        factory.setStrategy(HumanFaceFactory.Strategy.MRU);
        
        int counter = 0;
        for (int i = 0; i < faces.size(); i++) {
            String priFaceId = factory.loadFace(faces.get(i).toFile());
            HumanFace priFace = factory.getFace(priFaceId);
            names[i] = priFace.getShortName().replace("_01_ECA", "");
            
            // Starts with "i": the diagonal (face vs. itself) is computed as well, and both
            // directions of every pair are stored in one iteration.
            for (int j = i; j < faces.size(); j++) {
                priFace = factory.getFace(priFaceId); // re-read if dumped, at least update the access time
                
                String secFaceId = factory.loadFace(faces.get(j).toFile());
                HumanFace secFace = factory.getFace(secFaceId);
                
                System.out.print(++counter + ": " + priFace.getShortName() + " - " + secFace.getShortName());
                
                // During the first pass (i == 0) every other face is visited for the first time.
                // That is the only moment it is registered to the very first face and added to the
                // average face. Transformed faces are stored in the HumanFaceFactory, so the same
                // face must not be transformed again in later passes.
                boolean firstVisitOfSecFace = (i == 0 && j != 0);
                
                if (firstVisitOfSecFace) { 
                    System.out.print(", ICP");
                    long icpTime = System.currentTimeMillis();
                    IcpTransformer icp = new IcpTransformer(priFace.getMeshModel(), 100, true, 0.3, new NoUndersampling());
                    secFace.getMeshModel().compute(icp, true);
                    icpComputationTime += System.currentTimeMillis() - icpTime;
                }
                
                long hdTime = System.currentTimeMillis();
                // compute HD from secondary to primary:
                priFace.computeKdTree(true);
                HausdorffDistance hd = new HausdorffDistance(priFace.getKdTree(), Strategy.POINT_TO_POINT, false, true, CROP_HD);
                secFace.getMeshModel().compute(hd, true);
                distances[j][i] = hd.getStats().getAverage();
                // compute HD from primary to secondary:
                secFace.computeKdTree(true);
                hd = new HausdorffDistance(secFace.getKdTree(), Strategy.POINT_TO_POINT, false, true, CROP_HD);
                priFace.getMeshModel().compute(hd, true);
                distances[i][j] = hd.getStats().getAverage();
                hdComputationTime += System.currentTimeMillis() - hdTime;
                
                // Compute AVG face from the freshly transformed secondary face. Its k-d tree
                // has just been (re)computed above, after the ICP transformation.
                // The very first face (j == 0) is skipped.
                if (firstVisitOfSecFace) { 
                    System.out.print(", AVG");
                    long avgFaceTime = System.currentTimeMillis();
                    secFace.getKdTree().accept(avgFaceConstructor);
                    avgFaceComputationTime += System.currentTimeMillis() - avgFaceTime;
                }
                
                System.out.println(", " + factory.toString());
            }
            
            factory.removeFace(priFaceId); // the face is no longer needed
        }
        
        MeshObjExporter exp = new MeshObjExporter(avgFaceConstructor.getAveragedMeshModel());
        exp.exportModelToObj(new File(TEMPLATE_FACE_PATH));
        
        writeDistances(names, distances);
        
        System.out.println();
        System.out.println("Time of AVG face computation time: " + formatDuration(avgFaceComputationTime));
        System.out.println("ICP computation time: " + formatDuration(icpComputationTime));
        System.out.println("HD computation time: " + formatDuration(hdComputationTime));
        System.out.println("Total computation time: " + formatDuration(System.currentTimeMillis() - totalTime));
    }
    
    /**
     * Lists at most {@link #MAX_SAMPLES} {@code .obj} files of {@link #DATA_DIR}, sorted by path.
     * The directory stream is closed before returning.
     * 
     * @return paths of the face files
     * @throws IOException if the directory cannot be listed
     */
    private static List<Path> listFaces() throws IOException {
        try (Stream<Path> files = Files.list(new File(DATA_DIR).toPath())) {
            return files
                    .filter(f -> f.toString().endsWith(".obj"))
                    .sorted()
                    .limit(MAX_SAMPLES)
                    .collect(Collectors.toList());
        }
    }
    
    /**
     * Writes the upper triangle (including the diagonal) of the distance matrix to
     * {@link #OUTPUT_FILE} as a semicolon-separated table.
     * Each row contains both face names, both directional average distances, and then
     * the larger and the smaller of the two.
     * 
     * @param names short names of the faces
     * @param distances {@code distances[i][j]} is the average HD from face {@code i} to face {@code j}
     * @throws IOException on write error
     */
    private static void writeDistances(String[] names, double[][] distances) throws IOException {
        try (BufferedWriter w = new BufferedWriter(new FileWriter(OUTPUT_FILE))) {
            w.write("PRI FACE;SEC FACE;AVG HD from PRI to SEC;AVG HD from SEC to PRI;HIGHER AVG HD;LOWER AVG HD");
            w.newLine();
            for (int i = 0; i < names.length; i++) {
                for (int j = i; j < names.length; j++) {
                    double priToSec = distances[i][j];
                    double secToPri = distances[j][i];
                    w.write(names[i] + ";");
                    w.write(names[j] + ";");
                    w.write(String.format("%.8f", priToSec) + ";");
                    w.write(String.format("%.8f", secToPri) + ";");
                    // the larger directional distance first, then the smaller one
                    if (priToSec > secToPri) {
                        w.write(String.format("%.8f", priToSec) + ";");
                        w.write(String.format("%.8f", secToPri));
                    } else {
                        w.write(String.format("%.8f", secToPri) + ";");
                        w.write(String.format("%.8f", priToSec));
                    }
                    w.newLine();
                }
            }
        }
    }
    
    /**
     * Formats a time span as {@code HH:mm:ss,SSS}.
     * The hour field is the total number of hours, not wrapped at 24 (unlike
     * {@link Duration#toHoursPart()}), so runs longer than one day are reported correctly.
     * 
     * @param millis time span in milliseconds
     * @return formatted time span
     */
    private static String formatDuration(long millis) {
        Duration duration = Duration.ofMillis(millis);
        return String.format("%02d:%02d:%02d,%03d",
                duration.toHours(), duration.toMinutesPart(), duration.toSecondsPart(), duration.toMillisPart());
    }

}
