package cz.fidentis.analyst.visitors.mesh;

import cz.fidentis.analyst.mesh.MeshVisitor;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.mesh.core.MeshTriangle;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.vecmath.Vector3d;

/**
 * Mesh visitor computing Gaussian, mean, and principal curvatures of mesh vertices.
 * <p>
 * <b>All curvature algorithms suppose that the triangle vertices are oriented clockwise!</b>
 * </p>
 * 
 * @see <a href="https://computergraphics.stackexchange.com/questions/1718/what-is-the-simplest-way-to-compute-principal-curvature-for-a-mesh-triangle">Computing principal curvature of a mesh triangle (Computer Graphics Stack Exchange)</a>
 * @see <a href="http://rodolphe-vaillant.fr/?e=20">Curvature computation notes by Rodolphe Vaillant</a>
 * 
 * @author Natalia Bebjakova
 * @author Radek Oslejsek
 */
public class Curvature extends MeshVisitor {
    
    private final Map<MeshFacet, List<Double>> gaussian = new HashMap<>();
    private final Map<MeshFacet, List<Double>> mean = new HashMap<>();
    private final Map<MeshFacet, List<Double>> minPrincipal = new HashMap<>();
    private final Map<MeshFacet, List<Double>> maxPrincipal = new HashMap<>();
    
    /**
     * Returns Gaussian curvatures for all inspected mesh facets. The order corresponds to
     * the order of vertices, i.e., i-th value represents the curvature of i-th mesh vertex.
     * 
     * @return Gaussian curvatures of inspected mesh facets (unmodifiable view).
     */
    public Map<MeshFacet, List<Double>> getGaussianCurvatures() {
        return Collections.unmodifiableMap(gaussian);
    }
    
    /**
     * Returns mean curvatures for all inspected mesh facets. The order corresponds to
     * the order of vertices, i.e., i-th value represents the curvature of i-th mesh vertex.
     * 
     * @return Mean curvatures of inspected mesh facets (unmodifiable view).
     */
    public Map<MeshFacet, List<Double>> getMeanCurvatures() {
        return Collections.unmodifiableMap(mean);
    }
    
    /**
     * Returns minimum principal curvatures for all inspected mesh facets. The order corresponds to
     * the order of vertices, i.e., i-th value represents the curvature of i-th mesh vertex.
     * 
     * @return Minimum principal curvatures of inspected mesh facets (unmodifiable view).
     */
    public Map<MeshFacet, List<Double>> getMinPrincipalCurvatures() {
        return Collections.unmodifiableMap(minPrincipal);
    }
    
    /**
     * Returns maximum principal curvatures for all inspected mesh facets. The order corresponds to
     * the order of vertices, i.e., i-th value represents the curvature of i-th mesh vertex.
     * 
     * @return Maximum principal curvatures of inspected mesh facets (unmodifiable view).
     */
    public Map<MeshFacet, List<Double>> getMaxPrincipalCurvatures() {
        return Collections.unmodifiableMap(maxPrincipal);
    }
    
    /**
     * Computes Gaussian, mean, and principal curvatures of every vertex of the facet.
     * <p>
     * Boundary vertices get {@code 0.0} for all curvatures; vertices without any
     * neighbouring triangle get {@code NaN}. A facet that has already been visited is skipped.
     * </p>
     * 
     * @param facet Mesh facet to inspect
     */
    @Override
    public void visitMeshFacet(final MeshFacet facet) {
        final List<Double> gaussList;
        final List<Double> meanList;
        final List<Double> minList;
        final List<Double> maxList;
        
        synchronized (this) {
            if (gaussian.containsKey(facet)) {
                return; // already visited facet
            }
            gaussList = new ArrayList<>();
            meanList = new ArrayList<>();
            minList = new ArrayList<>();
            maxList = new ArrayList<>();
            gaussian.put(facet, gaussList);
            mean.put(facet, meanList);
            minPrincipal.put(facet, minList);
            maxPrincipal.put(facet, maxList);
        }
        // From here on, only the local list references are used, so the shared
        // HashMaps are never read outside the synchronized block.
        
        final List<CurvTriangle> triangles = precomputeTriangles(facet);
        for (int vertIndex = 0; vertIndex < facet.getNumberOfVertices(); vertIndex++) {
            final List<Integer> neighbouringTriangles = 
                    facet.getCornerTable().getTriangleIndexesByVertexIndex(vertIndex);
            
            if (facet.vertexIsBoundary(vertIndex)) {
                gaussList.add(0.0);
                meanList.add(0.0);
                minList.add(0.0);
                maxList.add(0.0);
                continue;
            }
        
            if (neighbouringTriangles.isEmpty()) {
                gaussList.add(Double.NaN);
                meanList.add(Double.NaN);
                minList.add(Double.NaN);
                maxList.add(Double.NaN);
                continue;
            }
            
            final MeshPoint centralPoint = facet.getVertex(vertIndex);
            final int numNeighbours = neighbouringTriangles.size();
            double sumArea = 0.0;
            double sumAngles = 0.0;
            final Vector3d pointSum = new Vector3d();
        
            // for all surrounding triangles:
            for (int i = 0; i < numNeighbours; i++) {
                final CurvTriangle ctri = triangles.get(neighbouringTriangles.get(i));
                // NOTE(review): assumes the corner table lists the triangles in fan order,
                // so that tNext shares the edge A-C of ctri — TODO confirm.
                final CurvTriangle tNext = triangles.get(neighbouringTriangles.get((i + 1) % numNeighbours));
            
                sumArea += computeArea(ctri, centralPoint);
                sumAngles += ctri.alpha(centralPoint);
            
                // cotangent-weighted edge vector (x_j - x_i) * (cot beta + cot gamma)
                final Vector3d edge = new Vector3d(ctri.vertC(centralPoint));
                edge.sub(ctri.vertA(centralPoint));
                edge.scale(ctri.betaCotan(centralPoint) + tNext.gammaCotan(centralPoint));
                pointSum.add(edge);
            }
            
            // Angle deficit divided by the mixed area.
            final double gaussVal = (2.0 * Math.PI - sumAngles) / sumArea;
            // |Laplace-Beltrami(x)| = |sum| / (2A) = 2H  =>  H = |sum| / (4A)
            final double meanVal = pointSum.length() / (4.0 * sumArea);
            // Clamp to avoid sqrt of small negative values caused by discretization.
            final double sqrtDelta = Math.sqrt(Math.max(0.0, meanVal * meanVal - gaussVal));
            
            gaussList.add(gaussVal);
            meanList.add(meanVal);
            minList.add(meanVal - sqrtDelta);
            maxList.add(meanVal + sqrtDelta);
        }
    }
    
    /**
     * Wraps every triangle of the facet into a {@link CurvTriangle} that caches its angles.
     * The i-th element corresponds to the i-th triangle returned by the facet's iterator.
     * 
     * @param facet Mesh facet
     * @return Cached triangles of the facet
     */
    protected List<CurvTriangle> precomputeTriangles(MeshFacet facet) {
        final List<CurvTriangle> ret = new ArrayList<>(facet.getNumTriangles());
        for (MeshTriangle tri : facet) {
            ret.add(new CurvTriangle(facet, tri));
        }
        return ret;
    }
    
    /**
     * Computes the contribution of a triangle to the mixed Voronoi area of vertex A
     * (Meyer et al.). Non-obtuse triangles use the Voronoi formula. For obtuse triangles,
     * the result is half of the triangle area if the obtuse angle is at A and a quarter otherwise.
     * 
     * @param ctri Triangle adjacent to the vertex
     * @param vertA Central point of the 1-ring neighborhood
     * @return Area of the triangle attributed to vertex A
     */
    protected double computeArea(CurvTriangle ctri, MeshPoint vertA) {
        final double alpha = ctri.alpha(vertA);
        final double beta = ctri.beta(vertA);
        final double gamma = ctri.gamma(vertA);
            
        final double piHalf = Math.PI / 2.0; // 90 degrees
        // A right triangle is non-obtuse, so the Voronoi formula is valid for it.
        // Using '>' everywhere keeps both branches consistent at exactly 90 degrees.
        if (alpha > piHalf || beta > piHalf || gamma > piHalf) { // obtuse triangle
            return (alpha > piHalf) ? ctri.area() / 2.0 : ctri.area() / 4.0;
        }
        
        final double cotBeta = ctri.betaCotan(vertA);
        final double cotGamma = ctri.gammaCotan(vertA);
        final double ab = ctri.lengthSquaredAB(vertA);
        final double ac = ctri.lengthSquaredAC(vertA);
        return (ab * cotGamma + ac * cotBeta) / 8.0;
    }
    
    /**
     * Helper class that caches triangle characteristics used multiple times during the curvature computation.
     * <ul>
     * <li>A = central point of the 1-ring neighborhood</li>
     * <li>B = point "on the left" (previous point in the clockwise orientation of the triangle)</li>
     * <li>C = point "on the right" (next point in the clockwise orientation of the triangle)</li>
     * </ul>
     * 
     * @author Radek Oslejsek
     */
    protected class CurvTriangle {
        
        private final MeshFacet facet;
        private final MeshTriangle tri;
        
        private final double v1Angle;
        private final double v2Angle;
        private final double v3Angle;
        
        private final double v1AngleCotang;
        private final double v2AngleCotang;
        private final double v3AngleCotang;
        
        /**
         * Constructor. Computes and caches the inner angles of the triangle and their cotangents.
         * 
         * @param facet Mesh facet owning the triangle
         * @param tri Triangle
         */
        public CurvTriangle(MeshFacet facet, MeshTriangle tri) {
            this.facet = facet;
            this.tri = tri;
            
            Vector3d a = new Vector3d(facet.getVertex(tri.index3).getPosition());
            a.sub(facet.getVertex(tri.index2).getPosition());
            
            Vector3d b = new Vector3d(facet.getVertex(tri.index1).getPosition());
            b.sub(facet.getVertex(tri.index3).getPosition());
            
            Vector3d c = new Vector3d(facet.getVertex(tri.index2).getPosition());
            c.sub(facet.getVertex(tri.index1).getPosition());
            
            a.normalize();
            b.normalize();
            c.normalize();
            
            b.scale(-1.0);
            double cos1 = c.dot(b); // (v2 - v1) . (v3 - v1)
            
            c.scale(-1.0);
            double cos2 = a.dot(c); // (v3 - v2) . (v1 - v2)
            
            a.scale(-1.0);
            b.scale(-1.0);
            double cos3 = a.dot(b); // (v2 - v3) . (v1 - v3)
            
            // Rounding may push the dot product of unit vectors slightly outside [-1, 1],
            // which would make Math.acos return NaN.
            this.v1Angle = Math.acos(Math.max(-1.0, Math.min(1.0, cos1)));
            this.v2Angle = Math.acos(Math.max(-1.0, Math.min(1.0, cos2)));
            this.v3Angle = Math.acos(Math.max(-1.0, Math.min(1.0, cos3)));
            
            this.v1AngleCotang = 1.0 / Math.tan(v1Angle);
            this.v2AngleCotang = 1.0 / Math.tan(v2Angle);
            this.v3AngleCotang = 1.0 / Math.tan(v3Angle);
        }
        
        /**
         * Returns the area of the triangle, computed as 1/2 * |AB| * |AC| * sin(alpha)
         * with A being the first vertex of the triangle.
         * 
         * @return Area of the triangle.
         */
        public double area() {
            return 0.5 * Math.sqrt(lengthSquaredAC(tri.vertex1) * lengthSquaredAB(tri.vertex1)) * Math.sin(v1Angle);
        }
        
        /**
         * Returns the position of vertex A. 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return vertex A, or {@code null} if {@code facetPoint} is not a vertex of this triangle.
         */
        public Vector3d vertA(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return tri.vertex1.getPosition();
            } else if (facetPoint == tri.vertex2) {
                return tri.vertex2.getPosition();
            } else if (facetPoint == tri.vertex3) {
                return tri.vertex3.getPosition();
            }
            return null;
        }

        /**
         * Returns the position of vertex B (previous vertex in the triangle's order). 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return vertex B, or {@code null} if {@code facetPoint} is not a vertex of this triangle.
         */
        public Vector3d vertB(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return tri.vertex3.getPosition();
            } else if (facetPoint == tri.vertex2) {
                return tri.vertex1.getPosition();
            } else if (facetPoint == tri.vertex3) {
                return tri.vertex2.getPosition();
            }
            return null;
        }

        /**
         * Returns the position of vertex C (next vertex in the triangle's order). 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return vertex C, or {@code null} if {@code facetPoint} is not a vertex of this triangle.
         */
        public Vector3d vertC(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return tri.vertex2.getPosition();
            } else if (facetPoint == tri.vertex2) {
                return tri.vertex3.getPosition();
            } else if (facetPoint == tri.vertex3) {
                return tri.vertex1.getPosition();
            }
            return null;
        }

        /**
         * Returns cached angle in the vertex A. 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Alpha angle, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double alpha(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return v1Angle;
            } else if (facetPoint == tri.vertex2) {
                return v2Angle;
            } else if (facetPoint == tri.vertex3) {
                return v3Angle;
            }
            return Double.NaN;
        }

        /**
         * Returns cached angle in the vertex B.
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Beta angle, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double beta(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return v3Angle;
            } else if (facetPoint == tri.vertex2) {
                return v1Angle;
            } else if (facetPoint == tri.vertex3) {
                return v2Angle;
            }
            return Double.NaN;
        }

        /**
         * Returns cached angle in the vertex C. 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Gamma angle, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double gamma(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return v2Angle;
            } else if (facetPoint == tri.vertex2) {
                return v3Angle;
            } else if (facetPoint == tri.vertex3) {
                return v1Angle;
            }
            return Double.NaN;
        }

        /**
         * Returns cached cotangent of the angle in the vertex A. 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Cotangent of the alpha angle, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double alphaCotan(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return v1AngleCotang;
            } else if (facetPoint == tri.vertex2) {
                return v2AngleCotang;
            } else if (facetPoint == tri.vertex3) {
                return v3AngleCotang;
            }
            return Double.NaN;
        }

        /**
         * Returns cached cotangent of the angle in the vertex B.
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Cotangent of the beta angle, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double betaCotan(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return v3AngleCotang;
            } else if (facetPoint == tri.vertex2) {
                return v1AngleCotang;
            } else if (facetPoint == tri.vertex3) {
                return v2AngleCotang;
            }
            return Double.NaN;
        }

        /**
         * Returns cached cotangent of the angle in the vertex C. 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Cotangent of the gamma angle, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double gammaCotan(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return v2AngleCotang;
            } else if (facetPoint == tri.vertex2) {
                return v3AngleCotang;
            } else if (facetPoint == tri.vertex3) {
                return v1AngleCotang;
            }
            return Double.NaN;
        }
        
        /**
         * Returns squared length of the edge AB (opposite to the vertex C). 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Squared length of the edge AB, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double lengthSquaredAB(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return distanceSquared(tri.index1, tri.index3);
            } else if (facetPoint == tri.vertex2) {
                return distanceSquared(tri.index2, tri.index1);
            } else if (facetPoint == tri.vertex3) {
                return distanceSquared(tri.index3, tri.index2);
            }
            return Double.NaN;
        }
        
        /**
         * Returns squared length of the edge AC (opposite to the vertex B). 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Squared length of the edge AC, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double lengthSquaredAC(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return distanceSquared(tri.index1, tri.index2);
            } else if (facetPoint == tri.vertex2) {
                return distanceSquared(tri.index2, tri.index3);
            } else if (facetPoint == tri.vertex3) {
                return distanceSquared(tri.index3, tri.index1);
            }
            return Double.NaN;
        }
        
        /**
         * Returns squared length of the edge BC (opposite to the vertex A). 
         * 
         * @param facetPoint Central point of the 1-ring neighborhood
         * @return Squared length of the edge BC, or {@code NaN} if {@code facetPoint} is not a vertex of this triangle.
         */
        public double lengthSquaredBC(MeshPoint facetPoint) {
            if (facetPoint == tri.vertex1) {
                return distanceSquared(tri.index2, tri.index3);
            } else if (facetPoint == tri.vertex2) {
                return distanceSquared(tri.index3, tri.index1);
            } else if (facetPoint == tri.vertex3) {
                return distanceSquared(tri.index1, tri.index2);
            }
            return Double.NaN;
        }
        
        /**
         * Returns the squared distance between two facet vertices given by their indices.
         * 
         * @param fromIndex Index of the first vertex in the facet
         * @param toIndex Index of the second vertex in the facet
         * @return Squared distance between the two vertices
         */
        private double distanceSquared(int fromIndex, int toIndex) {
            final Vector3d v = new Vector3d(facet.getVertex(toIndex).getPosition());
            v.sub(facet.getVertex(fromIndex).getPosition());
            return v.lengthSquared();
        }
    }

}
