package cz.fidentis.analyst.icp;

import cz.fidentis.analyst.kdtree.KdTree;
import cz.fidentis.analyst.mesh.MeshVisitor;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshModel;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.vecmath.Matrix4d;
import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;

/**
 * Visitor applying the Iterative Closest Point (ICP) algorithm to minimize 
 * Hausdorff distance of two triangular meshes. 
 * Inspected mesh facets are transformed (!) (vertices are moved) and the history 
 * of their transformations (transformation performed in each ICP iteration) is returned.
 * <p>
 * This visitor <strong>is not thread-safe</strong> due to efficiency reasons
 * and because of the algorithm principle, when it iterates multiple times through 
 * each inspected facet. Therefore, concurrent ICP transformation has to use 
 * an individual ICP visitor for each transformed (inspected) mesh facet.
 * Sequential inspection of multiple facets using a single visitor's instance 
 * is possible.
 * </p>
 * 
 * @author Maria Kocurekova
 * @author Radek Oslejsek
 */
public class IcpTransformer extends MeshVisitor {
    
    /**
     * Transformed mesh facets and their history of transformations.
     * Key = visited mesh facets. 
     * Value = History of transformations (transformation performed in each ICP iteration).
     */
    private final Map<MeshFacet, List<IcpTransformation>> transformations = new HashMap<>();
    
    /**
     * Maximal number of ICP iterations
     */
    private final int maxIteration;
    
    /**
     * Whether to scale mesh facets as well
     */
    private final boolean scale;
    
    /**
     * Acceptable error. When reached, then the ICP stops.
     */
    private final double error;
    
    /**
     * K-d tree of the primary triangular mesh(es).
     */
    private final KdTree primaryKdTree;
    
    /**
     * Strategy used to undersample the inspected facets before distances are measured.
     */
    private final UndersamplingStrategy reductionStrategy;
    
    /**
     * Constructor.
     * 
     * @param mainFacet Primary mesh facet. Must not be {@code null}. 
     * Inspected facets are transformed toward this primary mesh.
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error. A number bigger than or equal to zero. 
     * When reached, then the ICP stops. Reasonable number seems to be 0.05.
     * @param strategy One of the reduction strategies. If {@code null}, then {@link NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(MeshFacet mainFacet, int maxIteration, boolean scale, double error, UndersamplingStrategy strategy) {
        // The null check must happen while evaluating the this(...) argument;
        // a check placed after this(...) would run only after the k-d tree was built.
        this(singletonFacetSet(mainFacet), maxIteration, scale, error, strategy);
    }

    /**
     * Constructor.
     * 
     * @param mainFacets Primary mesh facets. Must not be {@code null}. 
     * Inspected facets are transformed toward these primary meshes. 
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error. A number bigger than or equal to zero. 
     * When reached, then the ICP stops. Reasonable number seems to be 0.05.
     * @param strategy One of the reduction strategies. If {@code null}, then {@link NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(Set<MeshFacet> mainFacets, int maxIteration, boolean scale, double error, UndersamplingStrategy strategy) {
        if (mainFacets == null) {
            throw new IllegalArgumentException("mainFacets");
        }
        checkIterationAndError(maxIteration, error);
        this.primaryKdTree = new KdTree(new ArrayList<>(mainFacets));
        this.error = error;
        this.maxIteration = maxIteration;
        this.scale = scale;
        this.reductionStrategy = (strategy == null) ? new NoUndersampling() : strategy;
    }

    /**
     * Constructor.
     * 
     * @param mainModel Primary mesh model. Must not be {@code null} and must contain
     * at least one facet. Inspected facets are transformed toward this primary mesh.
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error. A number bigger than or equal to zero. 
     * When reached, then the ICP stops. Reasonable number seems to be 0.05.
     * @param strategy One of the reduction strategies. If {@code null}, then {@link NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(MeshModel mainModel, int maxIteration, boolean scale, double error, UndersamplingStrategy strategy) {
        // Validated before delegation so that a null or empty model is reported
        // as IllegalArgumentException before any k-d tree is built.
        this(facetsOf(mainModel), maxIteration, scale, error, strategy);
    }
    
    /**
     * Constructor.
     * 
     * @param primaryKdTree The k-d tree of the primary mesh. Must not be {@code null}. 
     * Inspected facets are transformed toward this primary mesh.
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error. A number bigger than or equal to zero. When reached, then the ICP stops.
     * @param strategy One of the reduction strategies. If {@code null}, then {@link NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(KdTree primaryKdTree, int maxIteration, boolean scale, double error, UndersamplingStrategy strategy) {
        if (primaryKdTree == null) {
            throw new IllegalArgumentException("primaryKdTree");
        }
        checkIterationAndError(maxIteration, error);
        this.primaryKdTree = primaryKdTree;
        this.error = error;
        this.maxIteration = maxIteration;
        this.scale = scale;
        this.reductionStrategy = (strategy == null) ? new NoUndersampling() : strategy;
    }
    
    /**
     * Returns the history of transformations (transformation performed in each 
     * ICP iteration for each inspected mesh facet).
     * Keys in the map contain mesh facets that were inspected and transformed. 
     * For each transformed facet, a list of transformations to the primary mesh
     * is stored. The order of transformations corresponds to the order of ICP iterations, 
     * i.e., the i-th value is the transformation applied in the i-th iteration on the visited mesh facet.
     * It also means that the size of the list corresponds to the number of iterations 
     * performed for given mesh facet.
     * 
     * @return The history of transformations (transformation performed in each ICP iteration for each inspected mesh facet).
     */
    public Map<MeshFacet, List<IcpTransformation>> getTransformations() {
        return Collections.unmodifiableMap(transformations);
    }
    
    /**
     * This visitor <strong>is not thread-safe</strong>.
     * Therefore, concurrent ICP transformations have to use an individual ICP visitor for each 
     * transformed (inspected) mesh facet.
     * 
     * @return {@code false}.
     */
    @Override
    public boolean isThreadSafe() {
        return false;
    }
    
    /**
     * Returns the maximal number of ICP iterations used for the computation.
     * 
     * @return the maximal number of ICP iterations used for the computation
     */
    public int getMaxIterations() {
        return this.maxIteration;
    }
    
    /**
     * Returns maximal acceptable error used for the computation.
     * 
     * @return maximal acceptable error used for the computation
     */
    public double getError() {
        return this.error;
    }
    
    /**
     * Returns {@code true} if the inspected mesh faces were also scaled.
     * 
     * @return {@code true} if the inspected mesh faces were also scaled.
     */
    public boolean getScale() {
        return this.scale;
    }
    
    /**
     * Returns k-d tree of the primary mesh to which other meshes have been transformed.
     * 
     * @return k-d tree of the primary mesh
     */
    public KdTree getPrimaryKdTree() {
        return this.primaryKdTree;
    }
    
    /**
     * Iteratively transforms the given facet toward the primary mesh.
     * <p>
     * In each iteration, point-to-point distances between an undersampled version
     * of the facet and the primary mesh are computed. A transformation is derived
     * from them, appended to the facet's history (see {@link #getTransformations()})
     * and applied to the vertices of {@code transformedFacet}.
     * </p>
     * <p>
     * The iteration stops when any of these holds:
     * <ul>
     * <li>{@link #getMaxIterations()} iterations were performed;</li>
     * <li>the mean distances of the last two iterations differ by at most {@link #getError()};</li>
     * <li>no vertex has a nearest point on the primary mesh, so there is nothing to align.</li>
     * </ul>
     * </p>
     * 
     * @param transformedFacet Facet to be transformed. Its vertices are moved in place.
     */
    @Override
    public void visitMeshFacet(MeshFacet transformedFacet) {
        HausdorffDistance hausdorffDist = new HausdorffDistance(
                this.primaryKdTree, 
                HausdorffDistance.Strategy.POINT_TO_POINT, 
                false, // relative distance
                true   // parallel computation
        );
        
        MeshFacet reducedFacet = new UndersampledMeshFacet(transformedFacet, reductionStrategy);

        // Repeated visits of the same facet append to its existing history.
        List<IcpTransformation> history = transformations.computeIfAbsent(
                transformedFacet, facet -> new ArrayList<>(maxIteration));

        // Mean distances of the two most recent iterations; infinite = "not computed yet".
        double prevMeanD = Double.POSITIVE_INFINITY;
        double currMeanD = Double.POSITIVE_INFINITY;

        for (int iteration = 0; iteration < maxIteration && !hasConverged(prevMeanD, currMeanD); iteration++) {
            hausdorffDist.visitMeshFacet(reducedFacet); // repeated inspection re-calculates the distance
            List<Double> distances = hausdorffDist.getDistances().get(reducedFacet);
            List<Point3d> nearestPoints = hausdorffDist.getNearestPoints().get(reducedFacet);

            IcpTransformation transformation = computeIcpTransformation(nearestPoints, distances, reducedFacet);
            if (transformation == null) {
                // No correspondences: applying a NaN transformation would destroy the mesh.
                break;
            }

            history.add(transformation);
            applyTransformation(transformedFacet, transformation);

            prevMeanD = currMeanD;
            currMeanD = transformation.getMeanD();
        }
    }
    
    /**
     * Applies the given transformation to the vertices of the given facet in place.
     * <p>
     * Each vertex is rotated by the transformation's quaternion. If this transformer
     * was created with scaling enabled and the scale factor is not NaN, the result
     * is multiplied by the scale factor. Finally, the translation is added.
     * </p>
     *
     * @param transformedFacet Facet to be transformed
     * @param transformation Computed transformation. Must not be {@code null}.
     * @throws NullPointerException if {@code transformation} is {@code null}
     */
    public void applyTransformation(MeshFacet transformedFacet, IcpTransformation transformation) {
        Objects.requireNonNull(transformation, "transformation");

        // Invariant for all vertices; computed once instead of per vertex.
        Quaternion rotation = transformation.getRotation();
        Quaternion rotationConjugate = rotation.getConjugate();
        Vector3d translation = transformation.getTranslation();
        double scaleFactor = transformation.getScaleFactor();
        boolean applyScale = scale && !Double.isNaN(scaleFactor);

        transformedFacet.getVertices().parallelStream().forEach(p -> {
            Point3d position = p.getPosition();
            // rotated = q * v * conj(q); only its x, y, z components are used below.
            Quaternion vertex = new Quaternion(position.x, position.y, position.z, 1);
            Quaternion rotated = Quaternion.multiply(rotation, Quaternion.multiply(vertex, rotationConjugate));

            if (applyScale) {
                position.set(
                        rotated.x * scaleFactor + translation.x,
                        rotated.y * scaleFactor + translation.y,
                        rotated.z * scaleFactor + translation.z);
            } else {
                position.set(
                        rotated.x + translation.x,
                        rotated.y + translation.y,
                        rotated.z + translation.z);
            }
        });
    }

    /***********************************************************
     *  PRIVATE METHODS
     ***********************************************************/

    /**
     * Validates the iteration count and the acceptable error shared by all constructors.
     *
     * @param maxIteration Must be bigger than zero.
     * @param error Must be bigger than or equal to zero (NaN is rejected).
     * @throws IllegalArgumentException if some parameter is wrong
     */
    private static void checkIterationAndError(int maxIteration, double error) {
        if (maxIteration <= 0) {
            throw new IllegalArgumentException("maxIteration");
        }
        if (Double.isNaN(error) || error < 0.0) {
            throw new IllegalArgumentException("error");
        }
    }

    /**
     * Wraps a single primary facet into a mutable set after checking it for {@code null}.
     *
     * @param mainFacet Primary facet
     * @return a new set containing only {@code mainFacet}
     * @throws IllegalArgumentException if {@code mainFacet} is {@code null}
     */
    private static Set<MeshFacet> singletonFacetSet(MeshFacet mainFacet) {
        if (mainFacet == null) {
            throw new IllegalArgumentException("mainFacet");
        }
        return new HashSet<>(Collections.singleton(mainFacet));
    }

    /**
     * Collects the facets of the primary model into a mutable set after validating the model.
     *
     * @param mainModel Primary model
     * @return a new set with the model's facets
     * @throws IllegalArgumentException if {@code mainModel} is {@code null} or has no facets
     */
    private static Set<MeshFacet> facetsOf(MeshModel mainModel) {
        if (mainModel == null || mainModel.getFacets().isEmpty()) {
            throw new IllegalArgumentException("mainModel");
        }
        return new HashSet<>(mainModel.getFacets());
    }

    /**
     * Decides whether the ICP iteration can stop because two consecutive
     * iterations produced sufficiently similar mean distances.
     *
     * @param prevMeanD Mean distance of the previous iteration, infinite if not available
     * @param currMeanD Mean distance of the latest iteration
     * @return {@code true} if the iteration should stop
     */
    private boolean hasConverged(double prevMeanD, double currMeanD) {
        if (Double.isInfinite(prevMeanD)) {
            return false;
        }
        // Written as !(a > b) rather than a <= b so that a NaN difference
        // also stops the iteration instead of running until maxIteration.
        return !(Math.abs(prevMeanD - currMeanD) > error);
    }

    /**
     * Computes the transformation parameters (translation, rotation and scale factor)
     * of one ICP step. Based on the old FIDENTIS implementation.
     *
     * @param nearestPoints Nearest points of the primary mesh computed by the Hausdorff distance,
     *        indexed like the vertices of {@code transformedFacet}; {@code null} items are skipped.
     * @param distances Distances computed by the Hausdorff distance, indexed like {@code nearestPoints}.
     * @param transformedFacet The facet whose vertices are aligned to {@code nearestPoints}.
     * @return The ICP transformation, or {@code null} if no vertex has a nearest point
     *         (all derived quantities would be NaN).
     */
    private IcpTransformation computeIcpTransformation(List<Point3d> nearestPoints, List<Double> distances, MeshFacet transformedFacet) {
        List<MeshPoint> comparedPoints = transformedFacet.getVertices();

        List<Point3d> centers = computeCenterBothFacets(nearestPoints, comparedPoints);
        Point3d mainCenter = centers.get(0);
        Point3d comparedCenter = centers.get(1);

        double meanD = 0;
        int countOfNotNullPoints = 0;
        Matrix4d sumMatrix = new Matrix4d();
        Matrix4d multipleMatrix = new Matrix4d();

        for (int i = 0; i < comparedPoints.size(); i++) {
            Point3d nearest = nearestPoints.get(i);
            if (nearest == null) {
                continue;
            }
            meanD += distances.get(i);

            Point3d relativeC = relativeCoordinate(comparedPoints.get(i).getPosition(), comparedCenter);
            Point3d relativeN = relativeCoordinate(nearest, mainCenter);

            // Accumulate the 4x4 matrix from which the rotation quaternion is derived
            // (presumably Horn's quaternion method; see Quaternion/EigenvalueDecomposition).
            multipleMatrix.mulTransposeLeft(sumMatrixComp(relativeC), sumMatrixMain(relativeN));
            sumMatrix.add(multipleMatrix);

            countOfNotNullPoints++;
        }

        if (countOfNotNullPoints == 0) {
            return null;
        }

        meanD /= countOfNotNullPoints;

        Vector3d translation = new Vector3d(
                mainCenter.x - comparedCenter.x,
                mainCenter.y - comparedCenter.y, 
                mainCenter.z - comparedCenter.z);

        Quaternion q = new Quaternion(new EigenvalueDecomposition(sumMatrix));
        q.normalize();

        double scaleFactor = scale
                ? computeScaleFactor(nearestPoints, comparedPoints, mainCenter, comparedCenter, q)
                : 0;

        return new IcpTransformation(translation, q, scaleFactor, meanD);
    }

    /**
     * Computes the scale factor of one ICP step as {@code sxUp / sxDown}.
     * <p>
     * NOTE(review): every point is stored in row 0 of an otherwise zero 4x4 matrix
     * (see {@link #pointToMatrix(Point3d)}). Element (0,0) of {@code A^T * B} therefore
     * equals {@code A[0][0] * B[0][0]}. As a result, {@code sxUp} sums
     * {@code n.x * (c*R).x} and {@code sxDown} sums {@code (c*R).x^2}, so only the
     * x components contribute to the estimate. If an isotropic least-squares scale
     * was intended, full dot products should be used. Confirm against the original
     * FIDENTIS implementation before changing it.
     * </p>
     *
     * @param nearestPoints Nearest points of the primary mesh ({@code null} items are skipped)
     * @param comparedPoints Vertices of the transformed facet
     * @param mainCenter Center of the nearest points
     * @param comparedCenter Center of the compared points
     * @param rotation Normalized rotation of this ICP step
     * @return the scale factor (may be NaN or infinite for degenerate input)
     */
    private double computeScaleFactor(List<Point3d> nearestPoints, List<MeshPoint> comparedPoints,
            Point3d mainCenter, Point3d comparedCenter, Quaternion rotation) {
        Matrix4d rotationMatrix = rotation.toMatrix();
        double sxUp = 0;
        double sxDown = 0;
        for (int i = 0; i < nearestPoints.size(); i++) {
            if (nearestPoints.get(i) == null) {
                continue;
            }

            Matrix4d matrixPoint = pointToMatrix(relativeCoordinate(nearestPoints.get(i), mainCenter));
            Matrix4d matrixPointCompare = pointToMatrix(relativeCoordinate(comparedPoints.get(i).getPosition(), comparedCenter));

            matrixPointCompare.mul(rotationMatrix);

            matrixPoint.mulTransposeLeft(matrixPoint, matrixPointCompare);
            matrixPointCompare.mulTransposeLeft(matrixPointCompare, matrixPointCompare);

            sxUp += matrixPoint.getElement(0, 0);
            sxDown += matrixPointCompare.getElement(0, 0);
        }
        return sxUp / sxDown;
    }

    /**
     * Computes the centers (averages) of the nearest points and of the corresponding
     * compared points. Only indices with a non-null nearest point are taken into account.
     *
     * @param nearestPoints list of the nearest neighbours ({@code null} items are skipped)
     * @param comparedPoints list of compared points, indexed like {@code nearestPoints}
     * @return List with two points. The first one represents center of main facet
     * and the second one represents center of compared facet. Both are NaN if there
     * is no non-null nearest point.
     */
    private List<Point3d> computeCenterBothFacets(List<Point3d> nearestPoints, List<MeshPoint> comparedPoints) {
        double xN = 0;
        double yN = 0;
        double zN = 0;

        double xC = 0;
        double yC = 0;
        double zC = 0;

        int countOfNotNullPoints = 0;

        for (int i = 0; i < nearestPoints.size(); i++) {
            Point3d nearest = nearestPoints.get(i);
            if (nearest == null) {
                continue;
            }
            Point3d compared = comparedPoints.get(i).getPosition();

            xN += nearest.x;
            yN += nearest.y;
            zN += nearest.z;

            xC += compared.x;
            yC += compared.y;
            zC += compared.z;

            countOfNotNullPoints++;
        }

        List<Point3d> result = new ArrayList<>(2);
        result.add(new Point3d(xN / countOfNotNullPoints, yN / countOfNotNullPoints, zN / countOfNotNullPoints));
        result.add(new Point3d(xC / countOfNotNullPoints, yC / countOfNotNullPoints, zC / countOfNotNullPoints));
        return result;
    }

    /**
     * Computes the coordinates of the given point relative to the given center.
     *
     * @param p Point whose relative coordinates are to be computed.
     * @param center Center of the object.
     * @return New point {@code p - center}.
     */
    private Point3d relativeCoordinate(Point3d p, Point3d center) {
        return new Point3d(p.x - center.x, p.y - center.y, p.z - center.z);
    }

    /**
     * Builds the 4x4 matrix of the given compared point used in the rotation computation.
     *
     * @param p Compared point (relative to the compared center)
     * @return 4x4 matrix of the given point
     */
    private Matrix4d sumMatrixComp(Point3d p) {
        return new Matrix4d(
                0,   -p.x, -p.y, -p.z,
                p.x,  0,    p.z, -p.y,
                p.y, -p.z,  0,    p.x,
                p.z,  p.y, -p.x,  0);
    }

    /**
     * Builds the 4x4 matrix of the given point of the main facet used in the rotation computation.
     *
     * @param p Point of the main facet (relative to the main center)
     * @return 4x4 matrix of the given point
     */
    private Matrix4d sumMatrixMain(Point3d p) {
        return new Matrix4d(
                0,   -p.x, -p.y, -p.z,
                p.x,  0,   -p.z,  p.y,
                p.y,  p.z,  0,   -p.x,
                p.z, -p.y,  p.x,  0);
    }

    /**
     * Converts a point to a 4x4 matrix whose first row is {@code [x, y, z, 1]}
     * and whose remaining rows are zero.
     *
     * @param p Point to be converted.
     * @return Matrix of the point.
     */
    private Matrix4d pointToMatrix(Point3d p) {
        return new Matrix4d(
                p.x, p.y, p.z, 1,
                0,   0,   0,   0, 
                0,   0,   0,   0,
                0,   0,   0,   0);
    }
}
