package cz.fidentis.analyst.octree;

import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshFacetImpl;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.vecmath.Point3d;

/**
 * {@code Octree} for storing vertices ({@code MeshPoint}s) of 
 * triangular meshes ({@code MeshFacet}s).
 * Multiple mesh facets can be stored in a single {@code Octree}. In this case, 
 * vertices that are shared across multiple facets (have the same 3D location) 
 * are shared in the same node of the {@code Octree}.
 * 
 * @author Enkh-Undral EnkhBayar
 */
public class Octree implements Serializable {
    
    private OctNode root;
    
    /**
     * Edge length of the smallest cell of the {@code Octree} that holds vertices.
     * It is the smallest of dx, dy and dz of the root cell and of every leaf
     * cell holding vertices. It stays {@code Double.MAX_VALUE} for an empty tree.
     */ 
    private Double minLen = Double.MAX_VALUE;

    /**
     * Constructor. 
     * 
     * If no mesh points (vertices) are provided ({@code null} or an empty set),
     * then an empty {@code Octree} is constructed (with the root node set to null).
     *
     * @param points A set of individual mesh points. 
     */
    public Octree(Set<MeshPoint> points) {
        if (points == null || points.isEmpty()) {
            this.root = null;
            return;
        }
        MeshFacet newFacet = new MeshFacetImpl();
        for (MeshPoint point : points) {
            newFacet.addVertex(point);
        }
        buildTree(Collections.singletonList(newFacet));
    }

    /**
     * Constructor. 
     * 
     * If no mesh points (vertices) are provided (the facet is {@code null} or
     * has no vertices), then an empty {@code Octree} is constructed
     * (with the root node set to null).
     *
     * @param facet Mesh facet
     */
    public Octree(MeshFacet facet) {
        this(Collections.singletonList(facet));
    }
    
    /**
     * Constructor. 
     * 
     * If no mesh points (vertices) are provided (the list is {@code null}, empty,
     * or contains only {@code null} or vertex-less facets), then an empty 
     * {@code Octree} is constructed (with the root node set to null).
     * {@code null} facets in the list are ignored.
     * If multiple mesh facets share the same vertex, then they are stored 
     * efficiently in the same node of the {@code Octree}.
     *
     * @param facets The list of mesh facets to be stored. Facets can share vertices.
     */
    public Octree(List<MeshFacet> facets) {
        if (!hasVertices(facets)) {
            this.root = null;
            return;
        }
        buildTree(facets);
    }
    
    /**
     * Tree traversal - go to the "root" of the tree.
     * 
     * @return root node of the tree, or {@code null} for an empty tree
     */
    public OctNode getRoot() {
        return root;
    }
    
    /**
     * @return edge length of the smallest cell of the {@code Octree} that holds
     * vertices, or {@code Double.MAX_VALUE} for an empty tree
     */
    public Double getMinLen() {
        return minLen;
    }

    /** 
     * Recursively display the contents of the tree in a verbose format.
     * Individual nodes are represented in format (if the node does not hold a point)
     * [small boundary, large boundary]
     * and in format if it does
     * [small boundary, large boundary] point
     * 
     * @return representation of the tree, or an empty string for an empty tree
     */
    @Override
    public String toString() {
        return root == null ? "" : root.toString("");
    }

    /***********************************************************
     *  PRIVATE METHODS                                        *
     ***********************************************************/
    
    /**
     * @param facets facets to inspect, may be {@code null} or contain {@code null}s
     * @return {@code true} if at least one non-null facet has a vertex
     */
    private static boolean hasVertices(List<MeshFacet> facets) {
        if (facets == null) {
            return false;
        }
        for (MeshFacet facet : facets) {
            if (facet != null && !facet.getVertices().isEmpty()) {
                return true;
            }
        }
        return false;
    }
    
    /**
     * Expands the axis-aligned box [small, large] to a cube. The widest axis is
     * kept as is; the other two axes are widened symmetrically around their centers.
     *
     * @param small lower corner of the box
     * @param large upper corner of the box
     * @return two-element array {lower corner, upper corner} of the cube
     */
    private Point3d[] updateBoundaries(Point3d small, Point3d large) {
        double[] coorSmall = {small.x, small.y, small.z};
        double[] coorLarge = {large.x, large.y, large.z};
        int maxIndex = 0;
        for (int i = 1; i <= 2; i++) {
            if (coorLarge[i] - coorSmall[i] > coorLarge[maxIndex] - coorSmall[maxIndex]) {
                maxIndex = i;
            }
        }
        double halfMaxDiff = (coorLarge[maxIndex] - coorSmall[maxIndex]) / 2;
        for (int i = 0; i <= 2; i++) {
            if (i == maxIndex) {
                continue;
            }
            double average = (coorSmall[i] + coorLarge[i]) / 2;
            coorSmall[i] = average - halfMaxDiff;
            coorLarge[i] = average + halfMaxDiff;
        }
        return new Point3d[] {
            new Point3d(coorSmall[0], coorSmall[1], coorSmall[2]), 
            new Point3d(coorLarge[0], coorLarge[1], coorLarge[2])
        };
    }

    /**
     * Lowers {@code minLen} to the smallest edge of the given cell if it is smaller.
     */
    private void updateMinLen(Point3d smallestPoint, Point3d largestPoint) {
        minLen = Double.min(minLen, largestPoint.x - smallestPoint.x);
        minLen = Double.min(minLen, largestPoint.y - smallestPoint.y);
        minLen = Double.min(minLen, largestPoint.z - smallestPoint.z);
    }
    
    /**
     * Aggregates the vertices of all facets, computes a cubic bounding box
     * and builds the tree from the root.
     *
     * @param facets facets to store; {@code null} entries are skipped
     */
    private void buildTree(List<MeshFacet> facets) {
        Map<MeshPoint, AggregatedVertex> vertices = new HashMap<>();
        
        /*
         * Find bounding box and aggregate vertices
         * with the same 3D location.
         */
        Point3d smallestPoint = null;
        Point3d largestPoint = null;
        for (MeshFacet facet : facets) {
            if (facet == null) {
                continue; // a null facet contributes no vertices
            }
            int index = 0;
            for (MeshPoint p : facet.getVertices()) {
                AggregatedVertex aggregated = vertices.get(p);
                if (aggregated != null) {
                    aggregated.facets.add(facet);
                    aggregated.indices.add(index);
                } else {
                    vertices.put(p, new AggregatedVertex(facet, index));
                    Point3d point = p.getPosition();
                    if (smallestPoint == null) {
                        // copies, so the vertex positions are never mutated below
                        smallestPoint = new Point3d(point);
                        largestPoint = new Point3d(point);
                    } else {
                        smallestPoint.x = Double.min(smallestPoint.x, point.x);
                        smallestPoint.y = Double.min(smallestPoint.y, point.y);
                        smallestPoint.z = Double.min(smallestPoint.z, point.z);
                        largestPoint.x = Double.max(largestPoint.x, point.x);
                        largestPoint.y = Double.max(largestPoint.y, point.y);
                        largestPoint.z = Double.max(largestPoint.z, point.z);
                    }
                }
                index++;
            }
        }
        if (smallestPoint == null) {
            root = null; // nothing to store
            return;
        }
        Point3d[] boundaries = updateBoundaries(smallestPoint, largestPoint);
        updateMinLen(boundaries[0], boundaries[1]);
        root = buildTree(vertices, boundaries[0], boundaries[1]);
    }
    
    /**
     * Builds Octree.
     *
     * @param vertices aggregated vertices located inside the cell, keyed by mesh point
     * @param smallestPoint boundary coordinate with x, y, z lower or equal
     * than any other point in created OctNode
     * @param largestPoint boundary coordinate with x, y, z higher or equal
     * than any other point in created OctNode
     * @return new node of the Octree
     */
    private OctNode buildTree(Map<MeshPoint, AggregatedVertex> vertices, Point3d smallestPoint, Point3d largestPoint) {
        if (vertices.isEmpty()) {
            return new OctNode(smallestPoint, largestPoint);
        }
        
        if (vertices.size() == 1) {
            AggregatedVertex vertex = vertices.values().iterator().next();
            updateMinLen(smallestPoint, largestPoint);
            return new OctNode(vertex.facets, vertex.indices, smallestPoint, largestPoint);
        }
        
        Point3d middlePoint = new Point3d(
                (smallestPoint.x + largestPoint.x) / 2,
                (smallestPoint.y + largestPoint.y) / 2,
                (smallestPoint.z + largestPoint.z) / 2);

        List<Map<MeshPoint, AggregatedVertex>> octants = splitSpace(vertices, middlePoint);
        double[] xCoors = {smallestPoint.x, middlePoint.x, largestPoint.x};
        double[] yCoors = {smallestPoint.y, middlePoint.y, largestPoint.y};
        double[] zCoors = {smallestPoint.z, middlePoint.z, largestPoint.z};
        
        // Octant i uses bit 2 for x, bit 1 for y and bit 0 for z (0 = lower half).
        Point3d[] octantSmall = new Point3d[8];
        Point3d[] octantLarge = new Point3d[8];
        for (int i = 0; i < 8; i++) {
            int xi = i >> 2;
            int yi = (i >> 1) & 1;
            int zi = i & 1;
            octantSmall[i] = new Point3d(xCoors[xi], yCoors[yi], zCoors[zi]);
            octantLarge[i] = new Point3d(xCoors[xi + 1], yCoors[yi + 1], zCoors[zi + 1]);
            
            // All vertices fell into one octant whose bounds equal this cell:
            // they share a location (or cannot be separated at double precision),
            // so recursing would never terminate. Store them together instead.
            if (octants.get(i).size() == vertices.size()
                    && octantSmall[i].equals(smallestPoint)
                    && octantLarge[i].equals(largestPoint)) {
                return buildMergedLeaf(vertices.values(), smallestPoint, largestPoint);
            }
        }
        
        List<OctNode> doneOctants = new ArrayList<>(8);
        for (int i = 0; i < 8; i++) {
            doneOctants.add(buildTree(octants.get(i), octantSmall[i], octantLarge[i]));
        }
        return new OctNode(doneOctants, smallestPoint, largestPoint);
    }
    
    /**
     * Creates a single leaf holding all given aggregated vertices.
     *
     * @param aggregated vertices to be stored in the same node
     * @param smallestPoint lower corner of the leaf cell
     * @param largestPoint upper corner of the leaf cell
     * @return new leaf node
     */
    private OctNode buildMergedLeaf(Iterable<AggregatedVertex> aggregated, Point3d smallestPoint, Point3d largestPoint) {
        List<MeshFacet> facets = new ArrayList<>();
        List<Integer> indices = new ArrayList<>();
        for (AggregatedVertex vertex : aggregated) {
            facets.addAll(vertex.facets);
            indices.addAll(vertex.indices);
        }
        updateMinLen(smallestPoint, largestPoint);
        return new OctNode(facets, indices, smallestPoint, largestPoint);
    }
    
    /**
     * Splits all points into 8 octants separated by 3 planes which all 
     * intersect in middlePoint. A point lying exactly on a plane goes to
     * the lower side of that plane.
     *
     * @param vertices aggregated vertices to split
     * @param middlePoint middle point of space
     * @return 8 octants containing all vertices in space
     */
    private List<Map<MeshPoint, AggregatedVertex>> splitSpace(Map<MeshPoint, AggregatedVertex> vertices, Point3d middlePoint) {
        List<Map<MeshPoint, AggregatedVertex>> octants = new ArrayList<>(8);
        for (int i = 0; i < 8; i++) {
            octants.add(new HashMap<>());
        }
        for (Map.Entry<MeshPoint, AggregatedVertex> entry : vertices.entrySet()) {
            Point3d position = entry.getKey().getPosition();
            int octantIndex = 0;
            if (position.x > middlePoint.x) {
                octantIndex |= 4;
            }
            if (position.y > middlePoint.y) {
                octantIndex |= 2;
            }
            if (position.z > middlePoint.z) {
                octantIndex |= 1;
            }
            octants.get(octantIndex).put(entry.getKey(), entry.getValue());
        }
        return octants;
    }

    /***********************************************************
    *  EMBEDDED CLASSES
    ************************************************************/   

    /**
     * Helper class used during the Octree creation to store mesh vertices
     * with the same 3D location.
     * 
     * @author Radek Oslejsek
     */
    private static class AggregatedVertex {
        public final List<MeshFacet> facets = new ArrayList<>();
        public final List<Integer> indices = new ArrayList<>();
        
        AggregatedVertex(MeshFacet f, int i) {
            facets.add(f);
            indices.add(i);
        }
    }
}
