package cz.fidentis.analyst.kdtree;

import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshFacetImpl;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import javax.vecmath.Vector3d;


/**
 * KD-tree for storing vertices (MeshPoints) of triangular meshes (MeshFacets).
 * Multiple mesh facets can be stored in a single kd-tree. In this case, 
 * vertices that are shared across multiple facets (have the same 3D location) 
 * are shared in the same node of the kd-tree.
 * <p>
 * A node created at depth {@code d} splits its subtree along the axis
 * {@code d % 3} (0 = X, 1 = Y, 2 = Z): vertices preceding the node's location
 * in that axis' ordering are stored in the lesser subtree, the remaining ones
 * in the greater subtree.
 * </p>
 *
 * @author Maria Kocurekova
 */
public class KdTreeImpl {
    private KdNode root;
    
    /**
     * Constructor.
     *
     * @param points A set of individual mesh points. 
     * If no mesh points are provided ({@code null} or an empty set), then an empty 
     * KD-tree is constructed (with the root node set to null).
     */
    public KdTreeImpl(Set<MeshPoint> points) {
        if (points == null) {
            this.root = null;
            return;
        }
        // Wrap the points into a synthetic facet so that the common
        // facet-based construction can be reused.
        MeshFacet newFacet = new MeshFacetImpl();
        for (MeshPoint point : points) {
            newFacet.addVertex(point);
        }
        buildTree(Collections.singletonList(newFacet));
    }

    /**
     * Constructor. If no mesh points (vertices) are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     *
     * @param facet Mesh facet; {@code null} results in an empty KD-tree
     */
    public KdTreeImpl(MeshFacet facet) {
        this(Collections.singletonList(facet));
    }

    /**
     * Constructor. If no mesh points (vertices) are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     * If multiple mesh facets share the same vertex, then they are stored 
     * efficiently in the same node of the KD-tree.
     *
     * @param facets The list of mesh facets to be stored. Facets can share vertices.
     * {@code null} elements of the list are ignored.
     */
    public KdTreeImpl(List<MeshFacet> facets) {
        if (facets == null) {
            this.root = null;
            return;
        }
        // Emptiness is decided by buildTree() over ALL facets (it leaves the
        // root null when no vertex is found), so a leading empty facet can no
        // longer hide the vertices of the following ones.
        buildTree(facets);
    }
    
    /**
     * Finds the closest node for given mesh point.
     *
     * @param p Mesh point to be searched
     * @return the closest node of the kd-tree or null
     */
    public KdNode closestNode(MeshPoint p) {
        if (p == null) {
            return null;
        }
        return closestNode(p.getPosition());
    }

    /**
     * Finds the closest node for given 3D location.
     *
     * @param pointPos 3D location
     * @return the closest node of the kd-tree, or null if {@code pointPos}
     * is null or the tree is empty
     */
    public KdNode closestNode(Vector3d pointPos) {
        if (pointPos == null || this.root == null) {
            return null;
        }
        
        /*
         * First, descend along the path the query location would take when
         * inserted. This yields a good initial candidate and a tight bound
         * used for pruning in the second phase.
         */
        double minDistance = Double.MAX_VALUE;
        KdNode nearNode = root;
        KdNode searchedNode = root;

        while (searchedNode != null) {
            Vector3d nodePos = searchedNode.get3dLocation();
            double dist = MeshPoint.distance(pointPos, nodePos);
            if (dist < minDistance) {
                nearNode = searchedNode;
                minDistance = dist;
            }

            if (firstIsLessThanSecond(pointPos, nodePos, searchedNode.getDepth())) {
                searchedNode = searchedNode.getLesser();
            } else {
                searchedNode = searchedNode.getGreater();
            }
        }

        /*
         * Second, traverse the tree breadth-first and search for a vertex that
         * could be potentially closer than the nearest vertex already found.
         * The subtree on the far side of a node's splitting plane is skipped
         * when that plane is farther away than the current best distance.
         */
        Queue<KdNode> queue = new ArrayDeque<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            
            if (minDistance == 0) { // nothing can be closer
                break;
            }

            searchedNode = queue.poll();
            Vector3d nodePos = searchedNode.get3dLocation();
            int depth = searchedNode.getDepth();

            double dist = MeshPoint.distance(pointPos, nodePos);
            if (dist < minDistance) {
                nearNode = searchedNode;
                minDistance = dist;
            }

            double distOnAxis = minDistanceIntersection(nodePos, pointPos, depth);
            boolean planeIsFar = distOnAxis > minDistance;
            boolean onLesserSide = firstIsLessThanSecond(pointPos, nodePos, depth);

            // Visit the query's own side always; the opposite side only if
            // the splitting plane is within the current best distance.
            if ((!planeIsFar || onLesserSide) && searchedNode.getLesser() != null) {
                queue.add(searchedNode.getLesser());
            }
            if ((!planeIsFar || !onLesserSide) && searchedNode.getGreater() != null) {
                queue.add(searchedNode.getGreater());
            }
        }

        return nearNode;
    }

    /**
     * Checks if the kd-tree includes a node with 3D location corresponding 
     * to the given mesh point.
     * 
     * @param p Point whose location is searched
     * @return true if there is node covering the same 3D location, false otherwise
     * (including the case when {@code p} is null, has no position set,
     * or the tree is empty).
     */
    public boolean containsPoint(MeshPoint p) {
        KdNode node = closestNode(p);
        if (node == null) {
            return false;
        }
        // A non-null node implies that both p and its position are non-null.
        return p.getPosition().equals(node.get3dLocation());
    }

    /**
     * Returns concatenated string representations of all nodes
     * in breadth-first order (starting with the root).
     * 
     * @return string representation of the tree, empty for an empty tree
     */
    @Override
    public String toString() {
        StringBuilder ret = new StringBuilder();
        Queue<KdNode> queue = new ArrayDeque<>();
        if (root != null) {
            queue.add(root);
        }

        while (!queue.isEmpty()) {
            KdNode node = queue.poll();
            if (node.getLesser() != null) {
                queue.add(node.getLesser());
            }
            if (node.getGreater() != null) {
                queue.add(node.getGreater());
            }
            ret.append(node);
        }
        
        return ret.toString();
    }

    
    /***********************************************************
     *  PRIVATE METHODS
     ***********************************************************/  
    
    /**
     * Aggregates vertices of all facets by their 3D location and builds
     * the tree from them. Sets the root to null if no vertex is found.
     * 
     * @param facets Mesh facets; {@code null} elements are skipped
     */
    private void buildTree(List<MeshFacet> facets) {
        /*
         * Sort all vertices according to the splitting axis of the root (X),
         * aggregate vertices with the same 3D location.
         */
        SortedMap<Vector3d, AggregatedVertex> vertices = new TreeMap<>(comparatorForLevel(0));
        
        for (MeshFacet facet : facets) {
            if (facet == null) {
                continue;
            }
            int index = 0;
            for (MeshPoint p : facet.getVertices()) {
                Vector3d k = p.getPosition();
                AggregatedVertex aggregated = vertices.get(k);
                if (aggregated == null) {
                    vertices.put(k, new AggregatedVertex(facet, index));
                } else {
                    aggregated.facets.add(facet);
                    aggregated.indices.add(index);
                }
                index++;
            }
        }
        
        root = buildTree(null, vertices, 0);
    }
    
    /**
     * Builds kd-tree.
     *
     * @param parent Parent node
     * @param vertices Aggregated vertices sorted by the comparator of
     * {@code level}, i.e., primarily along the axis {@code level % 3}
     * @param level Tree depth that affects the splitting direction (x, y, or z)
     * @return new node of the kd-tree or null
     */
    private KdNode buildTree(KdNode parent, SortedMap<Vector3d, AggregatedVertex> vertices, int level) {
        
        if (vertices.isEmpty()) {
            return null;
        }
        
        // The median along this level's axis becomes the node.
        Vector3d pivot = findPivot(vertices);
        AggregatedVertex data = vertices.get(pivot);
        
        KdNode node = new KdNode(data.facets, data.indices, level, parent);
        
        // Children are split at the NEXT level, i.e. along the next axis, so
        // their vertices must be sorted by that axis; otherwise findPivot()
        // would partition them along the wrong axis and the search, which
        // uses the node's depth to pick the axis, would prune incorrectly.
        Comparator<Vector3d> childComparator = comparatorForLevel(level + 1);
        SortedMap<Vector3d, AggregatedVertex> left = new TreeMap<>(childComparator);
        SortedMap<Vector3d, AggregatedVertex> right = new TreeMap<>(childComparator);

        // headMap() excludes the pivot, tailMap() includes it.
        left.putAll(vertices.headMap(pivot));
        right.putAll(vertices.tailMap(pivot));
        right.remove(pivot);

        node.setLesser(buildTree(node, left, level + 1));
        node.setGreater(buildTree(node, right, level + 1));
        
        return node;
    }
    
    /**
     * Returns the comparator whose primary key is the splitting axis
     * of the given tree level.
     * 
     * @param level Tree depth
     * @return comparator prioritizing X (level % 3 == 0), Y (1), or Z (2)
     */
    private static Comparator<Vector3d> comparatorForLevel(int level) {
        switch (level % 3) {
            case 0:
                return new ComparatorX();
            case 1:
                return new ComparatorY();
            default:
                return new ComparatorZ();
        }
    }
    
    /**
     * Finds and return the middle key
     * @param map Map to be searched
     * @return middle key, or null for an empty map
     */
    private Vector3d findPivot(SortedMap<Vector3d, AggregatedVertex> map) {
        int mid = (map.size() / 2);
        int i = 0;
        for (Vector3d key: map.keySet()) {
            if (i++ == mid) {
                return key;
            }
        }
        return null;
    }
    
    /**
     * Compares two locations along the splitting axis of the given level.
     * 
     * @param v1 First location
     * @param v2 Second location
     * @param level Tree depth selecting the axis (level % 3: 0 = X, 1 = Y, 2 = Z)
     * @return true if {@code v1} is less than or equal to {@code v2} on that axis
     */
    private boolean firstIsLessThanSecond(Vector3d v1, Vector3d v2, int level){
        switch (level % 3) {
            case 0:
                return v1.x <= v2.x;
            case 1:
                return v1.y <= v2.y;
            case 2:
                return v1.z <= v2.z;
            default:
                break;
        }
        return false;
    }
    
    /**
     * Calculates the distance between the searched point and the splitting
     * plane of the currently visited node, i.e., the absolute difference of
     * their coordinates along the axis of the given level.
     *
     * @param nodePosition Location of the currently visited node
     * @param pointPosition Location whose nearest neighbor is searched
     * @param level Tree depth selecting the axis (level % 3: 0 = X, 1 = Y, 2 = Z)
     * @return distance measured along the splitting axis
     */
    private double minDistanceIntersection(Vector3d nodePosition, Vector3d pointPosition, int level){
        switch (level % 3) {
            case 0:
                return Math.abs(nodePosition.x - pointPosition.x);
            case 1:
                return Math.abs(nodePosition.y - pointPosition.y);
            default:
                return Math.abs(nodePosition.z - pointPosition.z);
        }
    }
   
    
    /***********************************************************
    *  EMBEDDED CLASSES
    ************************************************************/   

    /**
     * Helper class used during the kd-tree creation to store mesh vertices
     * with the same 3D location.
     * 
     * @author Radek Oslejsek
     */
    private static class AggregatedVertex {
        public final List<MeshFacet> facets = new ArrayList<>();
        public final List<Integer> indices = new ArrayList<>();
        
        AggregatedVertex(MeshFacet f, int i) {
            facets.add(f);
            indices.add(i);
        }
    }
    
    /**
     * Comparator prioritizing the X coordinate (then Y, then Z).
     * @author Radek Oslejsek
     */
    private static class ComparatorX implements Comparator<Vector3d> {
        @Override
        public int compare(Vector3d arg0, Vector3d arg1) {
            int diff = Double.compare(arg0.x, arg1.x);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.z, arg1.z);
        }    
    }
    
    /**
     * Comparator prioritizing the Y coordinate (then X, then Z).
     * @author Radek Oslejsek
     */
    private static class ComparatorY implements Comparator<Vector3d> {
        @Override
        public int compare(Vector3d arg0, Vector3d arg1) {
            int diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.x, arg1.x);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.z, arg1.z);
        }    
    }
    
    /**
     * Comparator prioritizing the Z coordinate (then Y, then X).
     * @author Radek Oslejsek
     */
    private static class ComparatorZ implements Comparator<Vector3d> {
        @Override
        public int compare(Vector3d arg0, Vector3d arg1) {
            int diff = Double.compare(arg0.z, arg1.z);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.x, arg1.x);
        }    
    }
}
