package cz.fidentis.analyst.visitors.kdtree;

import cz.fidentis.analyst.kdtree.KdTree;
import cz.fidentis.analyst.kdtree.KdTreeVisitor;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshFacetImpl;
import cz.fidentis.analyst.mesh.core.MeshModel;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;

/**
 * Visitor that creates an average face by deforming a given face template
 * so that its vertices are moved to the "average position" with respect to the visited faces.
 * The average position of a template vertex is the centroid of the closest points
 * found on the inspected faces.
 * <p>
 * <strong>
 * The inspected faces are expected to be already registered
 * (superimposed with the template face and with each other).
 * </strong>
 * <p>
 * Usage: construct the visitor with the template, let it visit the k-d tree of every
 * inspected face, then call {@link #getAveragedMeshModel()}.
 *
 * @author Radek Oslejsek
 */
public class AvgFaceConstructor extends KdTreeVisitor {

    /** Cached averaged model; {@code null} until computed, reset whenever a new face is visited. */
    private MeshModel avgMeshModel = null;

    /** Number of k-d trees (i.e., inspected faces) visited so far. */
    private int numInspectedFaces = 0;

    /**
     * Per template facet, the sum over all visited faces of the per-vertex shifts
     * (closest point minus template vertex). A {@link LinkedHashMap} keeps the facets of
     * the averaged model in the iteration order of the supplied template facets.
     */
    private final Map<MeshFacet, List<Vector3d>> transformations = new LinkedHashMap<>();

    /**
     * Constructor.
     *
     * @param templateFacet Mesh facet which is transformed to the averaged mesh.
     *        The original mesh remains unchanged. New mesh is allocated instead.
     * @throws IllegalArgumentException if {@code templateFacet} is {@code null}
     */
    public AvgFaceConstructor(MeshFacet templateFacet) {
        // Validation happens in a static helper: any check placed after this(...)
        // would run only after the delegated constructor had already failed.
        this(singletonFacetSet(templateFacet));
    }

    /**
     * Constructor.
     *
     * @param templateFacets Mesh facets that are transformed to the averaged mesh.
     *        The original mesh remains unchanged. New mesh is allocated instead.
     * @throws IllegalArgumentException if {@code templateFacets} is {@code null}, empty,
     *         or contains a {@code null} element
     */
    public AvgFaceConstructor(Set<MeshFacet> templateFacets) {
        if (templateFacets == null || templateFacets.isEmpty()) {
            throw new IllegalArgumentException("templateFacets");
        }

        // One zero shift vector per vertex of each facet. Filled sequentially on purpose:
        // the backing map is not thread-safe, so concurrent put() calls would race.
        for (MeshFacet f : templateFacets) {
            if (f == null) {
                throw new IllegalArgumentException("templateFacets");
            }
            transformations.put(f, Stream.generate(Vector3d::new)
                    .limit(f.getVertices().size())
                    .collect(Collectors.toList()));
        }
    }

    /**
     * Constructor.
     *
     * @param templateModel Mesh model which is transformed to the averaged mesh.
     *        The original mesh remains unchanged. New mesh is allocated instead.
     * @throws IllegalArgumentException if {@code templateModel} is {@code null}
     *         or has no facets
     */
    public AvgFaceConstructor(MeshModel templateModel) {
        this(facetsOf(templateModel));
    }

    /**
     * Wraps a single template facet into a mutable set, rejecting {@code null}.
     *
     * @param templateFacet template facet
     * @return set containing only {@code templateFacet}
     * @throws IllegalArgumentException if {@code templateFacet} is {@code null}
     */
    private static Set<MeshFacet> singletonFacetSet(MeshFacet templateFacet) {
        if (templateFacet == null) {
            throw new IllegalArgumentException("templateFacet");
        }
        return new HashSet<>(Collections.singleton(templateFacet));
    }

    /**
     * Collects the facets of a template model in their iteration order, rejecting {@code null}.
     *
     * @param templateModel template model
     * @return insertion-ordered set of the model's facets
     * @throws IllegalArgumentException if {@code templateModel} is {@code null}
     */
    private static Set<MeshFacet> facetsOf(MeshModel templateModel) {
        if (templateModel == null) {
            throw new IllegalArgumentException("templateModel");
        }
        return new LinkedHashSet<>(templateModel.getFacets());
    }

    /**
     * Adds, for every vertex of the template facets, the vector from the vertex to its
     * nearest point on the face stored in the given k-d tree to the accumulated shifts.
     * Invalidates any previously computed averaged model.
     *
     * @param kdTree k-d tree of an inspected (already registered) face
     */
    @Override
    public void visitKdTree(KdTree kdTree) {
        avgMeshModel = null; // AVG mesh model will be re-computed in the getAveragedMeshModel()
        numInspectedFaces++;

        // compute HD from the template facets to the mesh stored in the k-d tree:
        HausdorffDistance hDist = new HausdorffDistance(
                kdTree,
                HausdorffDistance.Strategy.POINT_TO_POINT,
                false, // relative distance
                true,  // parallel computation
                false  // auto crop
        );
        transformations.keySet().forEach(f -> hDist.visitMeshFacet(f));

        // accumulate shifts of the template vertices
        for (Map.Entry<MeshFacet, List<Vector3d>> entry : transformations.entrySet()) {
            MeshFacet myFacet = entry.getKey();
            List<Vector3d> shifts = entry.getValue();
            List<Point3d> closestPoints = hDist.getNearestPoints().get(myFacet);

            // each index updates its own accumulator vector
            IntStream.range(0, closestPoints.size()).parallel().forEach(i -> {
                Vector3d moveDir = new Vector3d(closestPoints.get(i));
                moveDir.sub(myFacet.getVertex(i).getPosition());
                shifts.get(i).add(moveDir);
            });
        }
    }

    /**
     * Computes and returns the averaged human face.
     * <p>
     * The template facets are copied, and every vertex of each copy is moved by the mean
     * of the shifts accumulated over all visited faces. The result is cached until the next
     * visit. If no face has been visited yet, the returned model contains unshifted copies
     * of the template facets.
     *
     * @return averaged human face
     */
    public MeshModel getAveragedMeshModel() {
        if (avgMeshModel != null) {
            return avgMeshModel;
        }

        // Build into a local so a failure part-way never leaves a half-built model cached.
        MeshModel result = new MeshModel();
        // Without any visit, 1.0/0 = Infinity and 0 * Infinity = NaN would corrupt every vertex.
        double scale = numInspectedFaces > 0 ? 1.0 / numInspectedFaces : 0.0;
        for (Map.Entry<MeshFacet, List<Vector3d>> entry : transformations.entrySet()) {
            MeshFacet newFacet = new MeshFacetImpl(entry.getKey()); // clone the template facet
            List<Vector3d> shifts = entry.getValue();
            if (numInspectedFaces > 0) {
                IntStream.range(0, newFacet.getNumberOfVertices()).parallel().forEach(i -> {
                    Vector3d tr = new Vector3d(shifts.get(i));
                    tr.scale(scale);
                    newFacet.getVertex(i).getPosition().add(tr);
                });
            }
            result.addFacet(newFacet);
        }
        avgMeshModel = result;
        return avgMeshModel;
    }

}
