package cz.fidentis.analyst.symmetry;

import cz.fidentis.analyst.grid.UniformGrid3d;
import cz.fidentis.analyst.grid.UniformGrid4d;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.mesh.core.MeshPointImpl;
import cz.fidentis.analyst.visitors.mesh.sampling.PointSampling;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;

/**
 * A robust implementation of symmetry plane estimation.
 * The code is based on the
 * https://link.springer.com/article/10.1007/s00371-020-02034-w paper.
 * This estimator <b>works with point clouds</b> (does not require manifold triangle mesh).
 * It has the following properties:
 * <ul>
 * <li>The Wendland's similarity function is used to get best candidate planes.</li>
 * <li>No additional weights are used.</li>
 * <li>The computation is accelerated by using a Uniform Grid for pruning candidate planes
 * and using two level downsampling: the radical downsampling for the generation 
 * of candidate planes (cca. 100 points), and less radical (cca. 1000 points) 
 * downsampling for the selection of the best candidate(s).</li>
 * <li>Best results are achieved with Uniform Grid sampling with 100 and 1000 
 * points (for the two phases).</li>
 * </ul>
 * 
 * @author Radek Oslejsek
 */
public class SymmetryEstimatorRobust extends SymmetryEstimator {

    /**
     * Minimum value of {@code SymmetryPlane.getNumAverages()} required for 
     * a candidate plane to be kept by {@link #generateCandidates}.
     */
    private static final int MIN_NUM_AVERAGES = 4;

    private final PointSampling samplingStrategy;
    private final int samplingLimit1;
    private final int samplingLimit2;
    
    /** Cached result; {@code null} if not computed yet or if no plane was found. */
    private Plane symmetryPlane;
    
    
    /**
     * Constructor.
     * 
     * @param samplingStrategy Downsampling strategy. Must not be {@code null}.
     *        Use {@code NoSampling} strategy to avoid downsampling.
     * @param samplingLimit1 Desired number of samples for finding candidate planes. 
     * @param samplingLimit2 Desired number of samples for finding the best candidate. 
     * @throws IllegalArgumentException if {@code samplingStrategy} is {@code null}
     */
    public SymmetryEstimatorRobust(PointSampling samplingStrategy, int samplingLimit1, int samplingLimit2) {
        if (samplingStrategy == null) {
            throw new IllegalArgumentException("samplingStrategy");
        }
        this.samplingStrategy = samplingStrategy;
        this.samplingLimit1 = samplingLimit1;
        this.samplingLimit2 = samplingLimit2;
    }
    
    /**
     * Computes and returns the symmetry plane. The plane is computed only once, 
     * then the same instance is returned until the visitor is applied to another mesh.
     * If the computation yields no plane, it is attempted again on the next call.
     * 
     * @return the symmetry plane or {@code null}
     */
    public Plane getSymmetryPlane() {
        if (symmetryPlane == null) {
            this.calculateSymmetryPlane();
        }
        return symmetryPlane;
    }
    
    /**
     * Passes the facet to the sampling strategy and invalidates the cached 
     * symmetry plane, so that the next {@link #getSymmetryPlane()} call 
     * takes the newly visited geometry into account.
     * 
     * @param facet Visited mesh facet
     */
    @Override
    public void visitMeshFacet(MeshFacet facet) {
        samplingStrategy.visitMeshFacet(facet);
        symmetryPlane = null; // a previously computed plane no longer reflects all visited facets
    }
    
    /**
     * Calculates the symmetry plane and stores it into the cache. 
     * The cache is set to {@code null} if there are no samples 
     * or no candidate planes.
     */
    protected void calculateSymmetryPlane() {
        this.symmetryPlane = null;
        
        //
        // Phase 1: Downsample the mesh, transform it so that the centroid is 
        //          in the space origin, and then compute candidate planes
        //
        samplingStrategy.setRequiredSamples(samplingLimit1);
        List<MeshPoint> meshSamples = samplingStrategy.getSamples();
        if (meshSamples == null || meshSamples.isEmpty()) {
            return; // nothing has been sampled => no centroid, no candidates
        }
        Point3d origCentroid = new MeshPointImpl(meshSamples).getPosition();
        Set<SymmetryPlane> candidates = generateCandidates(meshSamples, origCentroid);
        
        //
        // Phase 2: Downsample the mesh again (usually to more samples than before), 
        //          transform them in the same way as before, and measure their symmetry
        //          with respect to individual candidate planes.
        //
        samplingStrategy.setRequiredSamples(samplingLimit2);
        meshSamples = samplingStrategy.getSamples();
        measureSymmetry(meshSamples, origCentroid, candidates);
        
        //
        // Phase 3: "Adjust" the best 5 candidate planes so that they really 
        //          represent the local maxima using 
        //          the quasi-Newton optimization method L-BFGS
        // 
        // TO DO: candidates.stream().sorted().limit(5).forEach(optimize);
        
        //
        // Phase 4: Finally, get the best plane and move it back
        // 
        SymmetryPlane best = candidates.stream()
                .sorted()
                .findFirst() // the first plane in the natural ordering of candidates
                .orElse(null);
        if (best == null) {
            return;
        }
        
        // Work on a copy so that the candidate's own point is never modified.
        Point3d planePoint = new Point3d(best.getPlanePoint());
        planePoint.add(origCentroid);
        Vector3d normal = best.getNormal();
        // distance of the transformed surface point in the plane's normal direction
        double dist = normal.dot(new Vector3d(planePoint)) / normal.length();
        this.symmetryPlane = new Plane(normal, dist);
    }
    
    /**
     * Copies mesh samples, moves them to the space origin, and then computes candidate planes.
     * A candidate plane is constructed for every unordered pair of distinct samples. 
     * A candidate that has a close plane in the cache is merged (averaged) with it, 
     * otherwise it is stored as a new plane.
     * 
     * @param meshSamples Downsampled mesh
     * @param centroid Centroid of the downsampled mesh
     * @return Candidate planes
     */
    protected Set<SymmetryPlane> generateCandidates(List<MeshPoint> meshSamples, Point3d centroid) {
        ProcessedCloud cloud = new ProcessedCloud(meshSamples, centroid);
        UniformGrid4d<SymmetryPlane> planesCache = new UniformGrid4d<>(SymmetryPlane.GRID_SIZE);
        
        List<MeshPoint> pts = cloud.points;
        int count = pts.size();
        for (int i = 0; i < count; i++) {
            for (int j = i + 1; j < count; j++) { // each unordered pair exactly once
                SymmetryPlane candPlane = new SymmetryPlane(
                        pts.get(i).getPosition(),
                        pts.get(j).getPosition(),
                        cloud.avgDistance);
                SymmetryPlane closestPlane = candPlane.getClosestPlane(planesCache);
                if (closestPlane == null) { // add
                    planesCache.store(candPlane.getEstimationVector(), candPlane);
                } else { // replace with averaged plane 
                    SymmetryPlane avgPlane = new SymmetryPlane(closestPlane, candPlane);
                    planesCache.remove(closestPlane.getEstimationVector(), closestPlane);
                    planesCache.store(avgPlane.getEstimationVector(), avgPlane);
                }
            }
        }

        return planesCache.getAll().stream()
                .filter(plane -> plane.getNumAverages() >= MIN_NUM_AVERAGES)
                .collect(Collectors.toSet()); 
    }
    
    /**
     * Copies mesh samples, moves them to the space origin, and then measures the quality of 
     * candidate planes. The results are stored in the candidate planes.
     * Candidates are processed in parallel. Does nothing if there are no candidates.
     * 
     * @param meshSamples Downsampled mesh
     * @param centroid Centroid of the downsampled mesh
     * @param candidates Candidate planes
     */
    protected void measureSymmetry(List<MeshPoint> meshSamples, Point3d centroid, Set<SymmetryPlane> candidates) {
        if (candidates.isEmpty()) {
            return; // nothing to measure, avoid building the cloud and the grid
        }
        
        ProcessedCloud cloud = new ProcessedCloud(meshSamples, centroid);
        // NOTE(review): 15.0 and 2.6 are presumably parameters of the similarity 
        // function taken from the referenced paper -- confirm against the paper.
        double alpha = 15.0 / cloud.avgDistance;
        UniformGrid3d<MeshPoint> samplesGrid = new UniformGrid3d<>(2.6 / alpha, cloud.points, (MeshPoint p) -> p.getPosition());
        candidates.parallelStream().forEach(plane -> {
            plane.measureSymmetry(cloud.points, samplesGrid, alpha);
        });
    }
    
    
    
    
    /********************************************************************
     * A helper class that copies input mesh points and moves them 
     * so that the given centroid is in the space origin.
     * 
     * @author Radek Oslejsek
     */
    protected class ProcessedCloud {

        private final List<MeshPoint> points;
        private final double avgDistance;

        /**
         * Moves orig points so that the centroid is in the space origin, copies
         * them into a new list. Also, computes the average distance of orig
         * points to the centroid (= average distance of shifted points into the
         * space origin). The input points are not modified.
         *
         * @param points Original mesh points
         * @param centroid Centroid.
         */
        ProcessedCloud(List<MeshPoint> points, Point3d centroid) {
            int count = points.size();
            List<MeshPoint> shifted = new ArrayList<>(count);
            double avg = 0.0;
            for (int i = 0; i < count; i++) {
                MeshPoint mp = points.get(i);
                Point3d p = new Point3d(mp.getPosition()); // copy, the original position stays intact
                p.sub(centroid);
                shifted.add(new MeshPointImpl(p, mp.getNormal(), mp.getTexCoord(), mp.getCurvature()));
                avg += Math.sqrt(p.x * p.x + p.y * p.y + p.z * p.z) / count;
            }
            this.points = shifted;
            this.avgDistance = avg;
        }
    }
}
