package cz.fidentis.analyst.icp;

import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshFacetImpl;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;

import javax.vecmath.Matrix4d;
import javax.vecmath.Quat4d;
import javax.vecmath.Vector3d;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

/**
 * Computes the ICP (Iterative Closest Point) algorithm, which iteratively aligns one
 * {@link cz.fidentis.analyst.mesh.core.MeshFacet} to another by minimizing the distance between them.
 * The class has no parameterized constructor; run the algorithm through {@link #getTransformedFacet}
 * and, if the defaults are not suitable, change them beforehand with {@link #setParameters}.
 *
 * @author Maria Kocurekova
 */
public class Icp {

    /**
     * Transformations computed by every run of {@link #getTransformedFacet}, in computation order.
     * The list is never cleared, so it accumulates across runs.
     */
    private final List<IcpTransformation> transformations = new LinkedList<>();

    /** Working copy of the compared facet that is transformed during the current run. */
    private MeshFacet transformedFacet = null;

    /** Maximum number of iterations (default 10). */
    private int maxIteration = 10;

    /** Whether a scale factor is computed and applied (default false). */
    private boolean scale;

    /**
     * Convergence threshold: iteration stops once the mean closest-point distance changes by at most
     * this value between two consecutive iterations (default 0.05).
     */
    private double error = 0.05;


    /**
     * Changes the default setup: maximum number of iterations, use of scaling and the convergence threshold.
     *
     * @param maxIteration Maximum number of iterations
     *                     (one iteration computes a new transformation and applies it).
     * @param scale If {@code true}, a scale factor is computed and applied as well.
     * @param e Convergence threshold: the iteration stops when the mean closest-point distance
     *          changes by no more than this value between two consecutive iterations.
     */
    public void setParameters(int maxIteration, boolean scale, double e){
        this.maxIteration = maxIteration;
        this.scale = scale;
        this.error = e;
    }

    /**
     * Runs the ICP algorithm. In each iteration the closest point on {@code mainFacet} is found for
     * every vertex of the working facet, a transformation is computed from these pairs and applied
     * to the working facet. Iteration stops when {@code maxIteration} iterations have run, when the
     * mean closest-point distance changes by no more than the allowed error between two consecutive
     * iterations, or when no vertex has a closest point at all.
     * Every computed transformation is appended to the list returned by {@link #getTransformations()}.
     *
     * @param mainFacet Fixed facet; {@code comparedFacet} is aligned to it.
     * @param comparedFacet Facet to be aligned. The transformations are applied to a new
     *                      {@code MeshFacetImpl} constructed from it.
     * @return The transformed facet with new vertex coordinates.
     */
    public MeshFacet getTransformedFacet(MeshFacet mainFacet, MeshFacet comparedFacet){
        transformedFacet = new MeshFacetImpl(comparedFacet);

        HausdorffDistance hausdorffDist = new HausdorffDistance(mainFacet, HausdorffDistance.Strategy.POINT_TO_POINT, false, false);

        // the two most recently computed transformations of this run (null until computed)
        IcpTransformation previous = null;
        IcpTransformation current = null;

        for (int iteration = 0; iteration < maxIteration && !hasConverged(previous, current); iteration++) {
            hausdorffDist.visitMeshFacet(transformedFacet);
            List<Double> distances = hausdorffDist.getDistances().get(transformedFacet);
            List<Vector3d> nearestPoints = hausdorffDist.getNearestPoints().get(transformedFacet);

            // Without a single closest-point pair every mean below would be 0/0 = NaN and the NaN
            // translation would be written into every vertex, so stop iterating instead.
            if (nearestPoints == null || nearestPoints.stream().noneMatch(Objects::nonNull)) {
                break;
            }

            previous = current;
            current = computeIcpTransformation(nearestPoints, distances);
            transformations.add(current);
            applyTransformation(current);
        }

        return transformedFacet;
    }

    /**
     * Returns a read-only view of all transformations computed so far, in computation order.
     * The list is not cleared between runs of {@link #getTransformedFacet}.
     *
     * @return Unmodifiable list of IcpTransformation.
     */
    public List<IcpTransformation> getTransformations() {
        return Collections.unmodifiableList(transformations);
    }


    /***********************************************************
     *  PRIVATE METHODS
     ***********************************************************/

    /**
     * Decides whether the iteration has converged, i.e. whether the mean distances of the two most
     * recent transformations differ by no more than {@link #error}. Fewer than two transformations,
     * or an infinite previous mean distance, never count as converged.
     *
     * @param previous Second most recent transformation, or {@code null} if there is none yet.
     * @param current Most recent transformation (non-null whenever {@code previous} is non-null).
     * @return {@code true} if the iteration should stop.
     */
    private boolean hasConverged(IcpTransformation previous, IcpTransformation current) {
        if (previous == null || Double.isInfinite(previous.getMeanD())) {
            return false;
        }
        // Written as !(diff > error) rather than (diff <= error) so that a NaN difference stops the loop.
        return !(Math.abs(previous.getMeanD() - current.getMeanD()) > error);
    }

    /**
     * Computes the transformation parameters (translation, rotation and scale factor) from the
     * closest-point pairs. Vertices whose closest point is {@code null} are skipped.
     * The caller guarantees that at least one closest point is non-null.
     *
     * @param nearestPoints Closest points computed by the Hausdorff distance, indexed like the
     *                      vertices of the working facet ({@code null} = no closest point).
     * @param distances Distances computed by the Hausdorff distance, indexed the same way.
     * @return ICP transformation holding the mean displacement, rotation, scale factor
     *         (0 when scaling is disabled) and mean distance.
     */
    private IcpTransformation computeIcpTransformation(List<Vector3d> nearestPoints, List<Double> distances) {
        List<MeshPoint> comparedPoints = transformedFacet.getVertices();
        double meanX = 0;
        double meanY = 0;
        double meanZ = 0;
        double meanD = 0;
        int countOfNotNullPoints = 0;

        List<Vector3d> centers = computeCenterBothFacets(nearestPoints, comparedPoints);
        Vector3d mainCenter = centers.get(0);
        Vector3d comparedCenter = centers.get(1);
        Matrix4d sumMatrix = new Matrix4d();
        Matrix4d multipleMatrix = new Matrix4d();

        for (int i = 0; i < comparedPoints.size(); i++) {
            Vector3d nearest = nearestPoints.get(i);
            if (nearest == null) {
                continue;
            }
            Vector3d compared = comparedPoints.get(i).getPosition();

            // translation: accumulate the displacement from the compared vertex to its closest point
            meanX += nearest.x - compared.x;
            meanY += nearest.y - compared.y;
            meanZ += nearest.z - compared.z;
            meanD += distances.get(i);

            // rotation: accumulate transpose(right-multiplication matrix of the centred compared point)
            // x (left-multiplication matrix of the centred closest point)
            multipleMatrix.mulTransposeLeft(
                    sumMatrixComp(relativeCoordinate(compared, comparedCenter)),
                    sumMatrixMain(relativeCoordinate(nearest, mainCenter)));
            sumMatrix.add(multipleMatrix);

            countOfNotNullPoints++;
        }

        meanD /= countOfNotNullPoints;
        meanX /= countOfNotNullPoints;
        meanY /= countOfNotNullPoints;
        meanZ /= countOfNotNullPoints;

        // NOTE(review): Quat4d.set(Matrix4d) converts a *rotation* matrix into a quaternion, but
        // sumMatrix is Horn's 4x4 matrix N (its diagonal terms cancel, so its trace is 0). In Horn's
        // closed-form method the optimal rotation is the eigenvector of N's largest eigenvalue
        // (components ordered w, x, y, z). Confirm before relying on this rotation.
        Quat4d rotation = new Quat4d();
        rotation.set(sumMatrix);
        rotation.normalize();

        double scaleFactor = 0;
        if (scale) {
            // NOTE(review): only element (0,0) of the products below is summed, so only the
            // x components contribute; and multiplying the row vector by rotationMatrix appears to
            // apply the transpose of the rotation used in applyTransformation. A least-squares scale
            // would be sum(main . R compared) / sum(|compared|^2) over all three components.
            Matrix4d rotationMatrix = new Matrix4d();
            rotationMatrix.set(rotation);
            double sxUp = 0;
            double sxDown = 0;

            for (int i = 0; i < nearestPoints.size(); i++) {
                if (nearestPoints.get(i) == null) {
                    continue;
                }
                Matrix4d matrixPoint = pointToMatrix(relativeCoordinate(nearestPoints.get(i), mainCenter));
                Matrix4d matrixPointCompare = pointToMatrix(relativeCoordinate(comparedPoints.get(i).getPosition(), comparedCenter));
                matrixPointCompare.mul(rotationMatrix);

                matrixPoint.transpose();
                matrixPoint.mul(matrixPointCompare);
                sxUp += matrixPoint.getElement(0, 0);

                matrixPointCompare.transpose();
                matrixPointCompare.mul(matrixPointCompare);
                sxDown += matrixPointCompare.getElement(0, 0);
            }

            scaleFactor = sxUp / sxDown;
        }

        return new IcpTransformation(new Vector3d(meanX, meanY, meanZ), rotation, scaleFactor, meanD);
    }

    /**
     * Applies the computed transformation to every vertex of the working facet (in place).
     * The rotation is skipped when the facet has at most one vertex; the scale factor is used only
     * when scaling is enabled and the factor is not NaN.
     *
     * @param transformation Computed transformation.
     */
    private void applyTransformation(IcpTransformation transformation) {
        List<MeshPoint> vertices = transformedFacet.getVertices();
        Vector3d translation = transformation.getTranslation();
        boolean rotate = vertices.size() > 1;
        // x * 1.0 == x exactly, so the unscaled case yields the same values as omitting the factor
        double scaleFactor = (scale && !Double.isNaN(transformation.getScaleFactor()))
                ? transformation.getScaleFactor()
                : 1.0;

        Quat4d rotation = null;
        Quat4d conjugateRotation = new Quat4d();
        if (rotate) {
            rotation = transformation.getRotation();
            conjugateRotation.conjugate(rotation);
        }

        for (MeshPoint vertex : vertices) {
            Vector3d position = vertex.getPosition();

            // NOTE(review): javax.vecmath's Quat4d(x, y, z, w) constructor normalizes its arguments,
            // so `point` holds p / |(p, 1)| rather than p (verify against the vecmath version in use).
            Quat4d point = new Quat4d(position.x, position.y, position.z, 1);

            Quat4d rotated = point;
            if (rotate) {
                // q * p * conj(q)
                rotated = new Quat4d();
                rotated.mul(point, conjugateRotation);
                rotated.mul(rotation, rotated);
            }

            // NOTE(review): the original position is added back, giving p + t + (rotated, roughly unit
            // length vector) instead of the usual s * R(p) + t. Changing this only makes sense together
            // with a correct rotation estimate (see computeIcpTransformation).
            position.set(
                    rotated.x * scaleFactor + translation.x + position.x,
                    rotated.y * scaleFactor + translation.y + position.y,
                    rotated.z * scaleFactor + translation.z + position.z);
        }
    }

    /**
     * Computes the centroids of the closest points and of the matching compared vertices, using only
     * the pairs whose closest point is non-null. If there is no such pair, both centroids are NaN;
     * callers must ensure at least one non-null closest point.
     *
     * @param nearestPoints list of the nearest neighbours ({@code null} = no neighbour)
     * @param comparedPoints list of compared points, indexed like {@code nearestPoints}
     * @return List with two points. The first one is the center of the closest points (main facet)
     *         and the second one is the center of the corresponding compared points.
     */
    private List<Vector3d> computeCenterBothFacets(List<Vector3d> nearestPoints, List<MeshPoint> comparedPoints) {
        double xN = 0;
        double yN = 0;
        double zN = 0;

        double xC = 0;
        double yC = 0;
        double zC = 0;

        int countOfNotNullPoints = 0;
        List<Vector3d> result = new ArrayList<>(2);

        for (int i = 0; i < nearestPoints.size(); i++) {
            Vector3d nearest = nearestPoints.get(i);
            if (nearest == null) {
                continue;
            }
            Vector3d compared = comparedPoints.get(i).getPosition();

            xN += nearest.x;
            yN += nearest.y;
            zN += nearest.z;

            xC += compared.x;
            yC += compared.y;
            zC += compared.z;

            countOfNotNullPoints++;
        }
        result.add(new Vector3d(xN / countOfNotNullPoints, yN / countOfNotNullPoints, zN / countOfNotNullPoints));
        result.add(new Vector3d(xC / countOfNotNullPoints, yC / countOfNotNullPoints, zC / countOfNotNullPoints));
        return result;
    }

    /**
     * Computes the coordinates of a point relative to a center, i.e. {@code p - center}.
     *
     * @param p Point whose relative coordinates are computed.
     * @param center Center of the object.
     * @return New vector {@code p - center}.
     */
    private Vector3d relativeCoordinate(Vector3d p, Vector3d center) {
        return new Vector3d(p.x - center.x, p.y - center.y, p.z - center.z);
    }

    /**
     * Builds the 4x4 matrix of right multiplication by the pure quaternion (0, p), with quaternion
     * components ordered (w, x, y, z). Used for the (centred) compared point.
     *
     * @param p Compared point (relative to its center)
     * @return Matrix M such that M * q equals q * (0, p)
     */
    private Matrix4d sumMatrixComp(Vector3d p) {
        return new Matrix4d(0, -p.x, -p.y, -p.z,
                p.x, 0, p.z, -p.y,
                p.y, -p.z, 0, p.x,
                p.z, p.y, -p.x, 0);
    }


    /**
     * Builds the 4x4 matrix of left multiplication by the pure quaternion (0, p), with quaternion
     * components ordered (w, x, y, z). Used for the (centred) closest point on the main facet.
     *
     * @param p Point from the main facet (relative to its center)
     * @return Matrix M such that M * q equals (0, p) * q
     */
    private Matrix4d sumMatrixMain(Vector3d p) {
        return new Matrix4d(0, -p.x, -p.y, -p.z,
                p.x, 0, -p.z, p.y,
                p.y, p.z, 0, -p.x,
                p.z, -p.y, p.x, 0);
    }

    /**
     * Converts a point to a matrix whose first row is (x, y, z, 1) and whose other rows are zero.
     *
     * @param p Point to convert.
     * @return Matrix holding the point in its first row.
     */
    private Matrix4d pointToMatrix(Vector3d p){
        return new Matrix4d(p.x, p.y, p.z, 1,
                0, 0, 0, 0,
                0, 0, 0, 0,
                0, 0, 0, 0);
    }

}