package cz.fidentis.analyst.mesh.core;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.vecmath.Vector3d;
import cz.fidentis.analyst.mesh.MeshVisitor;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * Default implementation of {@link MeshFacet}: a list of vertices plus a
 * {@link CornerTable} whose rows, taken three at a time, reference those vertices
 * by index and thereby define the triangles of the facet.
 *
 * @author Matej Lukes
 */
public class MeshFacetImpl implements MeshFacet {
    
    /**
     * Vertices of the facet; triangles refer to them by their index in this list.
     */
    private final List<MeshPoint> vertices = new ArrayList<>();
    
    /**
     * Corner table; every three consecutive rows describe one triangle
     * (see {@link #iterator()}).
     */
    private final CornerTable cornerTable;
    
    /**
     * Creates an empty mesh facet with no vertices and an empty corner table.
     */
    public MeshFacetImpl() {
        cornerTable = new CornerTable();
    }

    /**
     * Copy constructor of MeshFacet.
     *
     * <p>The vertex list is copied into a new list, but the {@code MeshPoint} instances
     * themselves are shared with {@code facet}. The corner table is copied through
     * {@code CornerTable}'s copy constructor.
     *
     * @param facet copied MeshFacet
     */
    public MeshFacetImpl(MeshFacet facet) {
        // Sharing the MeshPoint instances preserves encapsulation only because
        // MeshPoints are treated as immutable.
        vertices.addAll(facet.getVertices());
        cornerTable = new CornerTable(facet.getCornerTable());
    }
    
    /**
     * Accepts a visitor by calling {@code visitor.visit(this)}.
     *
     * @param visitor visitor to be applied to this facet
     */
    @Override
    public void accept(MeshVisitor visitor) {
        visitor.visit(this);
    }
    
    /**
     * Returns the vertex stored at the given index.
     *
     * @param index index of the vertex
     * @return vertex at {@code index}
     * @throws IndexOutOfBoundsException if {@code index} is out of range
     */
    @Override
    public MeshPoint getVertex(int index) {
        return vertices.get(index);
    }

    /**
     * Appends a vertex to the end of the vertex list.
     *
     * @param point vertex to add
     */
    @Override
    public void addVertex(MeshPoint point) {
        vertices.add(point);
    }

    /**
     * @return number of vertices in this facet
     */
    @Override
    public int getNumberOfVertices() {
        return vertices.size();
    }

    /**
     * @return read-only view of the vertex list (reflects later additions)
     */
    @Override
    public List<MeshPoint> getVertices() {
        return Collections.unmodifiableList(vertices);
    }

    /**
     * @return the internal corner table instance (not a copy)
     */
    @Override
    public CornerTable getCornerTable() {
        return cornerTable;
    }
    
    /**
     * Checks whether the vertices carry normals.
     *
     * <p>Only the first vertex is inspected. This assumes normals are present
     * either on all vertices or on none.
     *
     * @return {@code true} if the facet is non-empty and its first vertex has a normal
     */
    @Override
    public boolean hasVertexNormals() {
        return !vertices.isEmpty() && vertices.get(0).getNormal() != null;
    }
    
    /**
     * Accumulates, for every vertex, the (unnormalized) normals of the triangles
     * that use it, and then normalizes the sums.
     *
     * <p>REPLACE WITH BETTER IMPLEMENTATION
     *
     * <p>NOTE(review): the computed normals live only in a local map and are never
     * assigned to the vertices. Apart from the duplicate-point check, this method
     * currently has no observable effect. Writing them back needs a way to set or
     * replace a vertex normal on {@code MeshPoint}, which is not visible from this
     * class. TODO.
     *
     * @throws IllegalStateException if two vertices of the facet have the same position
     */
    @Override
    public void calculateVertexNormals() {
        // Keyed by vertex position; relies on Vector3d's value-based equals/hashCode.
        Map<Vector3d, Vector3d> normalMap = new HashMap<>();
        
        // Initialize a zero normal per vertex position and reject duplicate positions,
        // which would otherwise merge the normals of distinct vertices.
        for (MeshPoint point : vertices) {
            if (normalMap.put(point.getPosition(), new Vector3d(0, 0, 0)) != null) {
                throw new IllegalStateException("Duplicate mesh point in the MeshFacet: " + point);
            }
        }
        
        // Add each triangle's face normal to all three of its vertices. The face
        // normal is not normalized before it is accumulated.
        for (MeshTriangle t : this) {
            Vector3d triangleNormal = t.vertex3.subtractPosition(t.vertex1)
                    .crossProduct(t.vertex2.subtractPosition(t.vertex1))
                    .getPosition();
            normalMap.get(t.vertex1.getPosition()).add(triangleNormal);
            normalMap.get(t.vertex2.getPosition()).add(triangleNormal);
            normalMap.get(t.vertex3.getPosition()).add(triangleNormal);
        }
        
        // Normalize. Vertices not used by any triangle keep a zero vector;
        // normalizing it would produce NaN components.
        for (Vector3d normal : normalMap.values()) {
            if (normal.lengthSquared() > 0) {
                normal.normalize();
            }
        }
    }
    
    /**
     * Returns the number of triangles in the facet.
     *
     * <p>Each triangle occupies three consecutive corner-table rows (see
     * {@link #iterator()}), so this is the row count divided by three.
     *
     * @return number of triangles
     */
    @Override
    public int getNumTriangles() {
        return cornerTable.getSize() / 3;
    }
    
    /**
     * Collects all triangles that contain the given vertex.
     *
     * @param vertexIndex index of the vertex
     * @return new mutable list of the adjacent triangles (empty if there are none)
     */
    @Override
    public List<MeshTriangle> getAdjacentTriangles(int vertexIndex) {
        List<Integer> adjacentTrianglesI = cornerTable.getTriangleIndexesByVertexIndex(vertexIndex);
        List<MeshTriangle> ret = new ArrayList<>(adjacentTrianglesI.size());
        
        for (Integer triI : adjacentTrianglesI) {
            List<Integer> triVerticesI = cornerTable.getIndexesOfVerticesByTriangleIndex(triI);
            ret.add(new MeshTriangle(
                    getVertex(triVerticesI.get(0)),
                    getVertex(triVerticesI.get(1)),
                    getVertex(triVerticesI.get(2)),
                    triVerticesI.get(0),
                    triVerticesI.get(1),
                    triVerticesI.get(2)));
        }
        
        return ret;
    }
    
    /**
     * Finds the point closest to {@code point} on any triangle adjacent to the given
     * vertex.
     *
     * @param point reference point
     * @param vertexIndex index of the vertex whose adjacent triangles are searched
     * @return the closest point (the first one found on ties), or {@code null} if the
     *         vertex has no adjacent triangles
     */
    @Override
    public Vector3d getClosestAdjacentPoint(Vector3d point, int vertexIndex) {
        double dist = Double.POSITIVE_INFINITY;
        Vector3d ret = null;
        
        for (MeshTriangle tri : getAdjacentTriangles(vertexIndex)) {
            Vector3d projection = tri.getClosestPoint(point);
            double d = euclideanDistance(projection, point);
            if (d < dist) {
                dist = d;
                ret = projection;
            }
        }
        
        return ret;
    }
    
    /**
     * Computes the distance from {@code point} to the closest point on any triangle
     * adjacent to the given vertex.
     *
     * @param point reference point
     * @param vertexIndex index of the vertex whose adjacent triangles are searched
     * @return the minimal distance, or {@link Double#POSITIVE_INFINITY} if the vertex
     *         has no adjacent triangles
     */
    @Override
    public double curvatureDistance(Vector3d point, int vertexIndex) {
        double dist = Double.POSITIVE_INFINITY;
        
        for (MeshTriangle tri : getAdjacentTriangles(vertexIndex)) {
            double d = euclideanDistance(tri.getClosestPoint(point), point);
            if (d < dist) {
                dist = d;
            }
        }
        
        return dist;
    }
    
    /**
     * Euclidean distance between two points; neither argument is modified.
     */
    private static double euclideanDistance(Vector3d a, Vector3d b) {
        Vector3d diff = new Vector3d(a);
        diff.sub(b);
        return diff.length();
    }
    
    /**
     * Returns an iterator over the triangles of the facet. Triangle {@code k} is built
     * from corner-table rows {@code 3k}, {@code 3k+1} and {@code 3k+2}.
     * The iterator does not support {@code remove()}.
     *
     * @return iterator over the triangles
     */
    @Override
    public Iterator<MeshTriangle> iterator() {
        return new Iterator<MeshTriangle>() {
            /** Corner-table row of the first corner of the next triangle. */
            private int index;
    
            @Override
            public boolean hasNext() {
                // Require a complete triple of rows so that next() never reads past
                // the end of the corner table.
                return index + 2 < cornerTable.getSize();
            }

            @Override
            public MeshTriangle next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                
                int i1 = cornerTable.getRow(index).getVertexIndex();
                int i2 = cornerTable.getRow(index + 1).getVertexIndex();
                int i3 = cornerTable.getRow(index + 2).getVertexIndex();
                index += 3;
                
                return new MeshTriangle(vertices.get(i1), vertices.get(i2), vertices.get(i3), i1, i2, i3);
            }
        };
    }
    
}

