package cz.fidentis.analyst.icp;

import com.google.common.primitives.Doubles;
import cz.fidentis.analyst.kdtree.KdTree;
import cz.fidentis.analyst.mesh.MeshVisitor;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshModel;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.visitors.mesh.HausdorffDistance;
import cz.fidentis.analyst.visitors.mesh.sampling.NoSampling;
import cz.fidentis.analyst.visitors.mesh.sampling.PointSampling;
import cz.fidentis.analyst.visitors.mesh.sampling.RandomSampling;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.vecmath.Matrix4d;
import javax.vecmath.Point3d;
import javax.vecmath.Tuple4d;
import javax.vecmath.Vector3d;
import javax.vecmath.Vector4d;

/**
 * Visitor applying the Iterative Closest Point (ICP) algorithm to minimize 
 * Hausdorff distance of two triangular meshes. 
 * Inspected mesh facets are transformed (!) (vertices are moved) and the history 
 * of their transformations (transformation performed in each ICP iteration) is returned.
 * <p>
 * This visitor <strong>is not thread-safe</strong> due to efficiency reasons
 * and because of the algorithm principle, when it iterates multiple times through 
 * each inspected facet. Therefore, concurrent ICP transformation has to use 
 * an individual ICP visitor for each transformed (inspected) mesh facet.
 * Sequential inspection of multiple facets using a single visitor's instance 
 * is possible.
 * </p>
 * <p>
 * This implementation is based on 
 * <pre>
 * Ferkova, Z. Comparison and Analysis of Multiple 3D Shapes [online]. 
 * Brno, 2016 [cit. 2022-02-08]. 
 * Available from: <a href="https://is.muni.cz/th/wx40f/">https://is.muni.cz/th/wx40f/</a>. 
 * Master's thesis. Masaryk University, Faculty of Informatics.
 * </pre>
 * </p>
 * 
 * @author Maria Kocurekova
 * @author Radek Oslejsek
 * @author Zuzana Ferkova
 */
public class IcpTransformer extends MeshVisitor   {
    
    /**
     * Transformed mesh facets and their history of transformations.
     * Key = visited mesh facets. 
     * Value = History of transformations (transformation performed in each ICP iteration).
     */
    private final Map<MeshFacet, List<IcpTransformation>> transformations = new HashMap<>();
    
    /**
     * Maximal number of ICP iterations
     */
    private final int maxIteration;
    
    /**
     * Whether to scale mesh facets as well
     */
    private final boolean scale;
    
    /**
     * Acceptable error. When reached, then the ICP stops.
     */
    private final double error;
    
    /**
     * K-d tree of the primary triangular mesh(es).
     */
    private final KdTree primaryKdTree;
    
    private final PointSampling samplingStrategy;
    
    /**
     * Constructor for random sampling, which is the best performing downsampling strategy.
     * 
     * @param mainFacet Primary mesh facet. Must not be {@code null}. 
     * Inspected facets are transformed toward this primary mesh.
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error - a number bigger than or equal to zero. 
     * Mean distance of vertices is computed for each ICP iteration.
     * If the difference between the previous and current mean distances is less than the error,
     * then the ICP computation stops. Reasonable number seems to be 0.05.
     * @param numSamples Number of samples for downsampling. Use 1000 for best performance. Zero = no downsampling
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(MeshFacet mainFacet, int maxIteration, boolean scale, double error, int numSamples) {
        this(new HashSet<>(Collections.singleton(mainFacet)), maxIteration, scale, error, (numSamples == 0) ? new NoSampling() : new RandomSampling(numSamples));
        if (mainFacet == null) {
            throw new IllegalArgumentException("mainFacet");
        }
    }
    
    /**
     * Constructor.
     * 
     * @param mainFacet Primary mesh facet. Must not be {@code null}. 
     * Inspected facets are transformed toward this primary mesh.
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error - a number bigger than or equal to zero. 
     * Mean distance of vertices is computed for each ICP iteration.
     * If the difference between the previous and current mean distances is less than the error,
     * then the ICP computation stops. Reasonable number seems to be 0.05.
     * @param strategy One of the reduction strategies. If {@code null}, then {@code NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(MeshFacet mainFacet, int maxIteration, boolean scale, double error, PointSampling strategy) {
        this(new HashSet<>(Collections.singleton(mainFacet)), maxIteration, scale, error, strategy);
        if (mainFacet == null) {
            throw new IllegalArgumentException("mainFacet");
        }
    }

    /**
     * Constructor.
     * 
     * @param mainFacets Primary mesh facets. Must not be {@code null}. 
     * Inspected facets are transformed toward these primary mesh. 
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error - a number bigger than or equal to zero. 
     * Mean distance of vertices is computed for each ICP iteration.
     * If the difference between the previous and current mean distances is less than the error,
     * then the ICP computation stops. Reasonable number seems to be 0.05.
     * @param strategy One of the reduction strategies. If {@code null}, then {@code NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(Set<MeshFacet> mainFacets, int maxIteration, boolean scale, double error, PointSampling strategy) {
        if (mainFacets == null) {
            throw new IllegalArgumentException("mainFacets");
        }
        if (maxIteration <= 0) {
            throw new IllegalArgumentException("maxIteration");
        }
        if (error < 0.0) {
            throw new IllegalArgumentException("error");
        }
        this.primaryKdTree = new KdTree(new ArrayList<>(mainFacets));
        this.error = error;
        this.maxIteration = maxIteration;
        this.scale = scale;
        this.samplingStrategy = (strategy == null) ? new NoSampling() : strategy;
    }

    /**
     * Constructor.
     * 
     * @param mainModel Primary mesh model. Must not be {@code null}. 
     * Inspected facets are transformed toward this primary mesh.
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error - a number bigger than or equal to zero. 
     * Mean distance of vertices is computed for each ICP iteration.
     * If the difference between the previous and current mean distances is less than the error,
     * then the ICP computation stops. Reasonable number seems to be 0.05.
     * @param strategy One of the reduction strategies. If {@code null}, then {@code NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(MeshModel mainModel, int maxIteration, boolean scale, double error, PointSampling strategy) {
        this(new HashSet<>(mainModel.getFacets()), maxIteration, scale, error, strategy);
        if (mainModel.getFacets().isEmpty()) {
            throw new IllegalArgumentException("mainModel");
        }
    }
    
    /**
     * Constructor.
     * 
     * @param primaryKdTree The k-d tree of the primary mesh. Must not be {@code null}. 
     * Inspected facets are transformed toward this primary mesh.
     * @param maxIteration Maximal number of ICP iterations (it includes computing 
     * new transformation and applying it). A number bigger than zero.
     * Reasonable number seems to be 10.
     * @param scale If {@code true}, then the scale factor is also computed.
     * @param error Acceptable error - a number bigger than or equal to zero. 
     * Mean distance of vertices is computed for each ICP iteration.
     * If the difference between the previous and current mean distances is less than the error,
     * then the ICP computation stops. Reasonable number seems to be 0.05.
     * @param strategy One of the reduction strategies. If {@code null}, then {@code NoUndersampling} is used.
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public IcpTransformer(KdTree primaryKdTree, int maxIteration, boolean scale, double error, PointSampling strategy) {
        if (primaryKdTree == null) {
            throw new IllegalArgumentException("primaryKdTree");
        }
        if (maxIteration <= 0) {
            throw new IllegalArgumentException("maxIteration");
        }
        if (error < 0.0) {
            throw new IllegalArgumentException("error");
        }
        this.primaryKdTree = primaryKdTree;
        this.error = error;
        this.maxIteration = maxIteration;
        this.scale = scale;
        this.samplingStrategy = (strategy == null) ? new NoSampling() : strategy;
    }
    
    /**
     * Returns the history of transformations (transformation performed in each 
     * ICP iteration for each inspected mesh facet).
     * Keys in the map contain mesh facets that were inspected and transformed. 
     * For each transformed facet, a list of transformations to the primary mesh
     * is stored. The order of transformations corresponds to the order of ICP iterations, 
     * i.e., the i-th value is the transformation applied in the i-th iteration on the visited mesh facet.
     * It also means that the size of the list corresponds to the number of iterations 
     * performed for given mesh facet.
     * 
     * @return The history of transformations (transformation performed in each ICP iteration for each inspected mesh facet).
     */
    public Map<MeshFacet, List<IcpTransformation>> getTransformations() {
        return Collections.unmodifiableMap(transformations);
    }
    
    
    /**
     * This visitor is <strong>is not thread-safe</strong>.
     * Therefore, concurrent ICP transformations have to use an individual ICP visitor for each 
     * transformed (inspected) mesh facet.
     * 
     * @return {@code false}.
     */
    @Override
    public boolean isThreadSafe() {
        return false;
    }
    
    /**
     * Returns the maximal number of ICP iterations used for the computation.
     * 
     * @return the maximal number of ICP iterations used for the computation
     */
    public int getMaxIterations() {
        return this.maxIteration;
    }
    
    /**
     * Returns maximal acceptable error used for the computation.
     * 
     * @return maximal acceptable error used for the computation
     */
    public double getError() {
        return this.error;
    }
    
    /**
     * Returns {@code true} if the inspected mesh faces were also scaled.
     * 
     * @return {@code true} if the inspected mesh faces were also scaled.
     */
    public boolean getScale() {
        return this.scale;
    }
    
    /**
     * Returns k-d tree of the primary mesh to which other meshes has been transformed.
     * 
     * @return k-d tree of the primary mesh
     */
    public KdTree getPrimaryKdTree() {
        return this.primaryKdTree;
    }
    
    @Override
    public void visitMeshFacet(MeshFacet transformedFacet) {
        HausdorffDistance hausdorffDist = new HausdorffDistance(
                this.primaryKdTree, 
                HausdorffDistance.Strategy.POINT_TO_POINT, 
                false, // relative distance
                true,  // parallel computation
                false  // auto crop
        );
        
        //int numSamples = samplingStrategy.getNumDownsampledPoints(transformedFacet.getNumberOfVertices());
        //samplingStrategy.setRequiredSamples(200);
        MeshFacet reducedFacet = new UndersampledMeshFacet(transformedFacet, samplingStrategy);

        int currentIteration = 0;
        IcpTransformation transformation = null;
        
        if (!transformations.containsKey(transformedFacet)) {
            transformations.put(transformedFacet, new ArrayList<>(maxIteration));
        }

        double prevMeanD = Double.POSITIVE_INFINITY;

        while ((currentIteration < maxIteration) &&
                (Double.isInfinite(prevMeanD) || Math.abs(prevMeanD - Objects.requireNonNull(transformation).getMeanD()) > error )) {

            hausdorffDist.visitMeshFacet(reducedFacet); // repeated incpection re-calculates the distance
            List<Double> distances = hausdorffDist.getDistances().get(reducedFacet);
            List<Point3d> nearestPoints = hausdorffDist.getNearestPoints().get(reducedFacet);

            if (transformation != null) {
                prevMeanD = transformation.getMeanD();
            }

            transformation = computeIcpTransformation(nearestPoints, distances, reducedFacet.getVertices());
            transformations.get(transformedFacet).add(transformation);
            applyTransformation(transformedFacet, transformation);
            
            if (!samplingStrategy.isBackedByOrigMesh()) { // samples have to be transfomed as well
                applyTransformation(reducedFacet, transformation);
            }
            
            currentIteration++;
            
            //if (currentIteration >= 3) {
            //    samplingStrategy.setRequiredSamples(numSamples);
            //    reducedFacet = new UndersampledMeshFacet(transformedFacet, samplingStrategy);
            //}
        }
    }
    
    /***********************************************************
     *  PRIVATE METHODS
     ***********************************************************/
    
    /**
     * Compute transformation parameters (translation, rotation, and scale factor).
     * Based on old FIDENTIS implementation
     *
     * @param nearestPoints List of nearest points computed by Hausdorff distance.
     * @param distances list of distances computed by Hausdorff distance
     * @return Return ICP transformation which is represented by computed parameters.
     */
    private IcpTransformation computeIcpTransformation(List<Point3d> nearestPoints, List<Double> distances, List<MeshPoint>trPoints) {
        //
        // Compute translation and mean distance from centroids:
        //
        double meanD = 0;
        Point3d mainCenter = new Point3d(0, 0, 0);
        Point3d trCenter = new Point3d(0, 0, 0);
        int countOfNotNullPoints = 0;
        for (int i = 0; i < nearestPoints.size(); i++) {
            if (nearestPoints.get(i) != null && Doubles.isFinite(distances.get(i))) {
                mainCenter.add(nearestPoints.get(i));
                trCenter.add(trPoints.get(i).getPosition());
                countOfNotNullPoints++;
                meanD += distances.get(i);
            }
        }
        mainCenter.scale(1.0 / countOfNotNullPoints);
        trCenter.scale(1.0 / countOfNotNullPoints);

        Vector3d translation = new Vector3d(
                mainCenter.x - trCenter.x,
                mainCenter.y - trCenter.y, 
                mainCenter.z - trCenter.z
        );
        
        meanD /= countOfNotNullPoints;

        //
        // Compute rotation:
        //
        Matrix4d sumMat = new Matrix4d();
        for (int i = 0; i < trPoints.size(); i++) {
            if (nearestPoints.get(i) != null && Doubles.isFinite(distances.get(i))) {
                //Matrix4d tmpMat = sumMatrixComp(relativeCoord4d(trPoints.get(i).getPosition(), trCenter));
                //tmpMat.mul(sumMatrixMain(relativeCoord4d(nearestPoints.get(i), mainCenter)));
                //sumMat.add(tmpMat);            
                sumMat.add(multMat(
                        relativeCoord4d(trPoints.get(i).getPosition(), trCenter),
                        relativeCoord4d(nearestPoints.get(i), mainCenter)
                ));
            }
        }
        
        Quaternion q = new Quaternion(new EigenvalueDecomposition(sumMat));
        q.normalize();

        //
        // compute scale
        //
        double scaleFactor = 1.0;
        if (scale) {
            double sxUp = 0;
            double sxDown = 0;
            Matrix4d rotMat = q.toMatrix();
            for (int i = 0; i < nearestPoints.size(); i++) {
                if (nearestPoints.get(i) != null && Doubles.isFinite(distances.get(i))) {
                    Vector4d relC = relativeCoord4d(nearestPoints.get(i), mainCenter);
                    Vector4d relA = relativeCoord4d(trPoints.get(i).getPosition(), trCenter);
                    rotMat.transform(relA);                
                    sxUp += relC.dot(relA);
                    sxDown += relA.dot(relA);
                }
            }
            scaleFactor = sxUp / sxDown;
        }

        return new IcpTransformation(translation, q, scaleFactor, meanD);
    }

    /**
     * Apply computed transformation to compared facet.
     *
     * @param transformedFacet Facet to be transformed
     * @param transformation Computed transformation.
     */
    private void applyTransformation(MeshFacet transformedFacet, IcpTransformation transformation) {
        transformedFacet.getVertices().parallelStream()
                .forEach(p -> {
                    p.setPosition(transformation.transformPoint(p.getPosition(), scale));
                    if (p.getNormal() != null) {
                        p.setNormal(transformation.transformNormal(p.getNormal()));
                    }
                }
        );
    }

    private Vector4d relativeCoord4d(Point3d p, Point3d center) {
        return new Vector4d(p.x - center.x, p.y - center.y, p.z - center.z, 1.0);
    }

    /**
     * Compute sum matrix of given point. Given point represents compared point.
     *
     * @param p Compared point
     * @return Sum matrix of given point
     */
    @Deprecated
    private Matrix4d sumMatrixComp(Tuple4d p) {
        return new Matrix4d(
                0,    p.x,  p.y,  p.z,
               -p.x,  0,   -p.z,  p.y,
               -p.y,  p.z,  0,   -p.x,
               -p.z, -p.y,  p.x,  0);
    }
    
    /**
     * Compute sum matrix of given point. Given point is point from main facet.
     *
     * @param p Compared point
     * @return Sum matrix of given point
     */
    @Deprecated
    private Matrix4d sumMatrixMain(Tuple4d p) {
        return new Matrix4d(
                0,   -p.x, -p.y, -p.z,
                p.x,  0,   -p.z,  p.y,
                p.y,  p.z,  0,   -p.x,
                p.z, -p.y,  p.x,  0);
    }
    
    /**
     * Fast implementation of <pre>sumMatrixComp(a).mult(sumMatrixMain(b));</pre>
     */
    private Matrix4d multMat(Tuple4d a, Tuple4d b) {
        double axbx = a.x * b.x;
        double axby = a.x * b.y;
        double axbz = a.x * b.z;
        double aybx = a.y * b.x;
        double ayby = a.y * b.y;
        double aybz = a.y * b.z;
        double azbx = a.z * b.x;
        double azby = a.z * b.y;
        double azbz = a.z * b.z;
        
        return new Matrix4d(
                axbx+ayby+azbz,       aybz-azby,      -axbz+azbx,       axby-aybx,
                    -azby+aybz,  axbx-azbz-ayby,       axby+aybx,       axbz+azbx,
                     azbx-axbz,       aybx+axby,  ayby-azbz-axbx,       aybz+azby,
                    -aybx+axby,       azbx+axbz,       azby+aybz,  azbz-ayby-axbx
        );
    }
}
