package cz.fidentis.analyst.mesh.core;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.vecmath.Vector3d;
import cz.fidentis.analyst.mesh.MeshVisitor;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * Mesh facet is a compact triangular mesh without duplicated vertices.
 * <p>
 * Vertices are stored as a list of {@link MeshPoint}s. Triangles are encoded
 * in a {@link CornerTable}, where every three consecutive rows form one triangle.
 * Duplicate vertices (identical positions) can be merged by {@link #simplify()}.
 * </p>
 *
 * @author Matej Lukes
 */
public class MeshFacetImpl implements MeshFacet {
    
    /**
     * Centers of circumcircle of all triangles computed on demand. 
     * These points represent the point of Voronoi area used for Delaunay 
     * triangulation, for instance.
     */
    private List<Vector3d> voronoiPoints;
    
    /**
     * Vertices of the facet. The list is replaced as a whole by {@link #simplify()},
     * and its items are replaced by {@link #calculateVertexNormals()}.
     */
    private List<MeshPoint> vertices = new ArrayList<>();
    
    /**
     * Triangles of the facet; three consecutive rows describe one triangle.
     */
    private final CornerTable cornerTable;
    
    /**
     * Constructor of MeshFacet
     */
    public MeshFacetImpl() {
        cornerTable = new CornerTable();
    }

    /**
     * Copy constructor of MeshFacet
     *
     * @param facet copied MeshFacet
     */
    public MeshFacetImpl(MeshFacet facet) {
        vertices.addAll(facet.getVertices()); // encapsulation preserved - vertices MeshPoints are immutable)
        cornerTable = new CornerTable(facet.getCornerTable());
    }
    
    @Override
    public void accept(MeshVisitor visitor) {
        visitor.visitMeshFacet(this);
    }
    
    @Override
    public MeshPoint getVertex(int index) {
        return vertices.get(index);
    }

    @Override
    public void addVertex(MeshPoint point) {
        vertices.add(point);
    }

    @Override
    public int getNumberOfVertices() {
        return vertices.size();
    }

    @Override
    public List<MeshPoint> getVertices() {
        return Collections.unmodifiableList(vertices);
    }

    @Override
    public CornerTable getCornerTable() {
        return cornerTable;
    }
    
    /**
     * Checks whether the vertices carry normals. Only the first vertex is inspected.
     *
     * @return {@code true} if the facet is non-empty and its first vertex has a normal
     */
    @Override
    public boolean hasVertexNormals() {
        return !this.vertices.isEmpty() && this.vertices.get(0).getNormal() != null;
    }
    
    /**
     * Computes per-vertex normals and stores them in the vertices of this facet.
     * <p>
     * The normal of a vertex is the normalized sum of the normals of all triangles
     * sharing the vertex. The triangle normals are summed without normalization,
     * so larger triangles contribute more. Mesh points are treated as immutable,
     * so every vertex is replaced by a new {@link MeshPointImpl} that has the same
     * position and texture coordinate and the newly computed normal. A vertex that
     * belongs to no triangle gets a zero normal vector; normalizing a zero vector
     * would produce NaNs, so it is left as is.
     * </p>
     * <p>
     * REPLACE WITH BETTER IMPLEMENTATION
     * </p>
     *
     * @throws IllegalStateException if there are duplicate mesh points in the mesh facet
     *         (see {@link #simplify()})
     */
    @Override
    public synchronized void calculateVertexNormals() {
        Map<Vector3d, Vector3d> normalMap = new HashMap<>(); // key = vertex position, value = normal
        
        // init normals (the position is the key, so duplicates cannot be told apart):
        for (MeshPoint point: vertices) { 
            if (normalMap.put(point.getPosition(), new Vector3d(0, 0, 0)) != null) {
                throw new IllegalStateException("Duplicate mesh point in the MeshFacet: " + point);
            }
        }
        
        // accumulate normals of the triangles sharing each vertex:
        for (MeshTriangle t : this) { 
            Vector3d triangleNormal = (t.getPoint3().subtractPosition(t.getPoint1()))
                    .crossProduct(t.getPoint2().subtractPosition(t.getPoint1()))
                    .getPosition();
            normalMap.get(t.getPoint1().getPosition()).add(triangleNormal);
            normalMap.get(t.getPoint2().getPosition()).add(triangleNormal);
            normalMap.get(t.getPoint3().getPosition()).add(triangleNormal);
        }
        
        // normalize the normals and store them into (new) mesh points:
        for (int i = 0; i < vertices.size(); i++) {
            MeshPoint point = vertices.get(i);
            Vector3d normal = normalMap.get(point.getPosition());
            if (normal.lengthSquared() > 0) {
                normal.normalize();
            }
            vertices.set(i, new MeshPointImpl(point.getPosition(), normal, point.getTexCoord()));
        }
    }
    
    /**
     * Returns the number of triangles. Each triangle occupies three rows of the corner table.
     *
     * @return number of triangles
     */
    @Override
    public int getNumTriangles() {
        return cornerTable.getSize() / 3;
    }
    
    @Override
    public List<MeshTriangle> getTriangles() {
        List<MeshTriangle> ret = new ArrayList<>(getNumTriangles());
        for (MeshTriangle tri : this) {
            ret.add(tri);
        }
        return ret;
    }
    
    @Override
    public List<MeshTriangle> getAdjacentTriangles(int vertexIndex) {
        List<MeshTriangle> ret = new ArrayList<>();
        
        List<Integer> adjacentTrianglesI = cornerTable.getTriangleIndexesByVertexIndex(vertexIndex);
        for (Integer triI: adjacentTrianglesI) {
            List<Integer> triVerticesI = cornerTable.getIndexesOfVerticesByTriangleIndex(triI);
            ret.add(new MeshTriangle(
                    this,
                    triVerticesI.get(0),
                    triVerticesI.get(1),
                    triVerticesI.get(2)));
        }
        
        return ret;
    }
    
    /**
     * Finds the closest point lying on any triangle adjacent to the given vertex.
     *
     * @param point 3D point
     * @param vertexIndex index of the vertex whose adjacent triangles are searched
     * @return the closest projection, or {@code null} if the vertex has no adjacent triangle
     */
    @Override
    public Vector3d getClosestAdjacentPoint(Vector3d point, int vertexIndex) {
        double dist = Double.POSITIVE_INFINITY;
        Vector3d ret = null;
        
        for (MeshTriangle tri: this.getAdjacentTriangles(vertexIndex)) {
            Vector3d projection = tri.getClosestPoint(point);
            double d = distance(projection, point);
            if (d < dist) {
                dist = d;
                ret = projection;
            }
        }
        
        return ret;
    }
    
    /**
     * Computes the distance from the point to the closest triangle adjacent to the given vertex.
     *
     * @param point 3D point
     * @param vertexIndex index of the vertex whose adjacent triangles are searched
     * @return the minimal distance, or {@link Double#POSITIVE_INFINITY} if the vertex
     *         has no adjacent triangle
     */
    @Override
    public double curvatureDistance(Vector3d point, int vertexIndex) {
        double dist = Double.POSITIVE_INFINITY;
        
        for (MeshTriangle tri: this.getAdjacentTriangles(vertexIndex)) {
            double d = distance(tri.getClosestPoint(point), point);
            if (d < dist) {
                dist = d;
            }
        }
        
        return dist;
    }
    
    /**
     * Returns the Euclidean distance between two points.
     */
    private static double distance(Vector3d from, Vector3d to) {
        Vector3d aux = new Vector3d(from);
        aux.sub(to);
        return aux.length();
    }
    
    /**
     * Returns an iterator over the triangles. Each call of {@code next()} consumes
     * three consecutive rows of the corner table.
     *
     * @return triangle iterator
     */
    @Override
    public Iterator<MeshTriangle> iterator() {
        return new Iterator<MeshTriangle>() {
            /** Index of the first corner-table row of the next triangle. */
            private int index;
    
            @Override
            public boolean hasNext() {
                return index < cornerTable.getSize();
            }

            @Override
            public MeshTriangle next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                
                int i1 = cornerTable.getRow(index + 0).getVertexIndex();
                int i2 = cornerTable.getRow(index + 1).getVertexIndex();
                int i3 = cornerTable.getRow(index + 2).getVertexIndex();
        
                MeshTriangle tri = new MeshTriangle(MeshFacetImpl.this, i1, i2, i3);
        
                index += 3;
                return tri;
            }
        };
    }
    
    /**
     * Returns the Voronoi points of all triangles, in triangle order. They are
     * computed on the first call and then cached.
     *
     * @return unmodifiable list of Voronoi points
     */
    @Override
    public synchronized List<Vector3d> calculateVoronoiPoints() {
        if (voronoiPoints == null) {
            voronoiPoints = new ArrayList<>(getNumTriangles());
            for (MeshTriangle tri: this) {
                voronoiPoints.add(tri.getVoronoiPoint());
            }
        } 
        return Collections.unmodifiableList(voronoiPoints);
    }
    
    @Override
    public TriangleFan getOneRingNeighborhood(int vertexIndex) {
        if (vertexIndex < 0 || vertexIndex >= this.getNumberOfVertices()) {
            return null;
        }
        return new TriangleFan(this, vertexIndex);
    }
    
    /**
     * Merges vertices with identical positions into a single vertex.
     * <p>
     * A merged vertex is created from the position, normal and texture coordinate
     * of the first occurrence. The normals of later duplicates are accumulated with
     * {@link MeshPoint#addNormal}, and non-zero normals are then normalized. Merged
     * vertices keep the order of their first occurrence, and the corner table is
     * re-indexed to match.
     * </p>
     *
     * @return {@code true} if some vertices were merged; {@code false} if there were
     *         no duplicates, in which case the facet is left unchanged
     */
    @Override
    public boolean simplify() {
        int origCount = getNumberOfVertices();
        Map<Vector3d, Integer> newIndexByPosition = new HashMap<>();
        List<MeshPoint> newVertices = new ArrayList<>();
        int[] origToNew = new int[origCount]; // original vertex index -> new vertex index
        
        // aggregate duplicates, remember where every original vertex went:
        for (int i = 0; i < origCount; i++) {
            MeshPoint point = getVertex(i);
            Vector3d v = point.getPosition();
            Vector3d n = point.getNormal();
            Integer newIndex = newIndexByPosition.get(v);
            if (newIndex == null) { // first occurrence of this position
                newIndex = newVertices.size();
                newIndexByPosition.put(v, newIndex);
                newVertices.add(new MeshPointImpl(v, n, point.getTexCoord()));
            } else if (n != null) { // duplicate: accumulate its normal
                newVertices.set(newIndex, newVertices.get(newIndex).addNormal(n));
            }
            origToNew[i] = newIndex;
        }
        
        if (newVertices.size() == origCount) {
            return false;
        }
        
        // normalize accumulated normals (a zero vector would turn into NaNs):
        for (MeshPoint p : newVertices) {
            Vector3d normal = p.getNormal();
            if (normal != null && normal.lengthSquared() > 0) {
                normal.normalize();
            }
        }
       
        // update corner table:
        for (int i = 0; i < this.cornerTable.getSize(); i++) {
            int origIndex = cornerTable.getRow(i).getVertexIndex();
            CornerTableRow newRow = new CornerTableRow(cornerTable.getRow(i));
            newRow.setVertexIndex(origToNew[origIndex]);
            cornerTable.replaceRow(i, newRow);
        }
        
        this.vertices = newVertices;
        return true;
    }
    
    
}

