package cz.fidentis.analyst.kdtree;

import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshFacetImpl;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import javax.vecmath.Vector3d;


/**
 * KD-tree for storing vertices (MeshPoints) of triangular meshes (MeshFacets).
 * Multiple mesh facets can by stored in a single kd-tree. In this case, 
 * vertices that are shared across multiple facets (have the same 3D location) 
 * are shared in the same node of the kd-tree.
 *
 * @author Maria Kocurekova
 */
public class KdTreeImpl {
    private KdNode root;
    private List<MeshFacet> facets;
    
    /**
     * Constructor.
     *
     * @param points A set of individual mesh points. 
     * If no mesh points are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     */
    public KdTreeImpl(Set<MeshPoint> points) {
        if(points == null) {
            this.root = null;
            this.facets = null;
            return;
        }
        MeshFacet newFacet = new MeshFacetImpl();
        for(MeshPoint point : points) {
            newFacet.addVertex(point);
        }
        this.facets = new LinkedList<>(Collections.singleton(newFacet));
        buildTree();
    }

    /**
     * Constructor. If no mesh points (vertices) are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     *
     * @param facet Mesh facet
     */
    public KdTreeImpl(MeshFacet facet) {
        this(new LinkedList<>(Collections.singleton(facet)));
    }

    /**
     * Constructor. If no mesh points (vertices) are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     * If multiple mesh factes share the same vertex, then they are stored 
     * efficiently in the same node of the KD-tree.
     *
     * @param facets The list of mesh facets to be stored. Facets can share vertices.
     */
    public KdTreeImpl(List<MeshFacet> facets) {
        if(facets == null ||  facets.isEmpty() || facets.get(0).getVertices().isEmpty() ){
            this.root = null;
            this.facets = null;
            return;
        }
        this.facets = facets;
        buildTree();        
    }
    
    /**
     * Finds the closest node for given mesh point.
     *
     * @param p Mesh point to be searched
     * @return the closest node of the kd-tree or null
     */
    public KdNode closestNode(MeshPoint p) {
        if (p == null) {
            return null;
        }
        return closestNode(p.getPosition());
    }

    /**
     * Finds the closest node for given 3D location.
     *
     * @param pointPos 3D location
     * @return the closest node of the kd-tree or null
     */
    public KdNode closestNode(Vector3d pointPos) {
        if (pointPos == null || this.root == null) {
            return null;
        }
        
        /*
         * First, find the closest node
         */
        double minDistance = Double.MAX_VALUE;
        KdNode nearNode = root;
        KdNode searchedNode = root;

        while (searchedNode != null) {
            Vector3d nodePos = searchedNode.get3dLocation();
            double dist = MeshPoint.distance(pointPos, nodePos);
            if(dist < minDistance){
                nearNode = searchedNode;
                minDistance = dist;
            }

            if (firstIsLessThanSecond(pointPos, nodePos, searchedNode.getDepth())) {
                searchedNode = searchedNode.getLesser();
            }else{
                searchedNode = searchedNode.getGreater();
            }
        }

        /*
         * Second, search for vertex that could be potentially closer than
         * the nearest vertex already found
         */
        double distOnAxis;
        Queue<KdNode> queue = new LinkedList<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            
            if (minDistance == 0) { // nothing can be closer
                break;
            }

            searchedNode = queue.poll();
            Vector3d nodePos = searchedNode.get3dLocation();

            double dist = MeshPoint.distance(pointPos, nodePos);
            if (dist < minDistance) {
                nearNode = searchedNode;
                minDistance = dist;
            }

            distOnAxis = minDistanceIntersection(nodePos, pointPos, searchedNode.getDepth());

            if (distOnAxis > minDistance) {
                if (firstIsLessThanSecond(pointPos, nodePos, searchedNode.getDepth())) {
                    if (searchedNode.getLesser() != null) {
                        queue.add(searchedNode.getLesser());
                    }
                } else {
                    if (searchedNode.getGreater() != null) {
                        queue.add(searchedNode.getGreater());
                    }
                }
            } else {
                if (searchedNode.getLesser() != null) {
                    queue.add(searchedNode.getLesser());
                }
                if (searchedNode.getGreater() != null) {
                    queue.add(searchedNode.getGreater());
                }
            }
        }

        return nearNode;
    }

    /**
     * Checks if the kd-tree includes a node with 3D location corresponsding 
     * to the given mesh point.
     * 
     * @param p Point whose location is searched
     * @return true if there is node covering the same 3D location, false otherwise.
     * @throws NullPointerException if the input parameter is null or has no position set
     */
    public boolean containsPoint(MeshPoint p){
        KdNode node = closestNode(p);
        if (node == null) {
            return false;
        }
        return p.getPosition().equals(node.get3dLocation());
    }

        
    @Override
    public String toString() {
        String ret = "";
        Queue<KdNode> queue = new LinkedList<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            KdNode node = queue.poll();
            if (node == null) {
                continue;
            }
            queue.add(node.getLesser());
            queue.add(node.getGreater());
            ret += node.toString();
        }
        
        return ret;
    }

    
    /***********************************************************
     *  PRIVATE METHODS
     ***********************************************************/  
    
    private void buildTree() {
        SortedMap<Vector3d, SortData> byX = new TreeMap<>(new ComparatorX());
        SortedMap<Vector3d, SortData> byY = new TreeMap<>(new ComparatorY());
        SortedMap<Vector3d, SortData> byZ = new TreeMap<>(new ComparatorZ());
        
        for (MeshFacet facet: facets) {
            int index = 0;
            for (MeshPoint p: facet.getVertices()) {
                Vector3d k = p.getPosition();
                
                if (byX.containsKey(k)) {
                    byX.get(k).facets.add(facet);
                    byX.get(k).indices.add(index);
                } else {
                    byX.put(k, new SortData(facet, index));
                }

                if (byY.containsKey(k)) {
                    byY.get(k).facets.add(facet);
                    byY.get(k).indices.add(index);
                } else {
                    byY.put(k, new SortData(facet, index));
                }

                if (byZ.containsKey(k)) {
                    byZ.get(k).facets.add(facet);
                    byZ.get(k).indices.add(index);
                } else {
                    byZ.put(k, new SortData(facet, index));
                }

                index++;
            }
        }
        
        root = buildTree(null, byX, byY, byZ, 0);
    }
    
    /**
     * Building k-d tree
     *
     * @param parent node
     * @param byX list of points sorted by x axis
     * @param byY list of points sorted by y axis
     * @param byZ list of points sorted by z axis
     * @param level representation of coordinates
     * @return new k-d node
     */
    private KdNode buildTree(
            KdNode parent, 
            SortedMap<Vector3d, SortData> byX, 
            SortedMap<Vector3d, SortData> byY, 
            SortedMap<Vector3d, SortData> byZ, 
            int level) {
        
        KdNode node = null;
        
        if (byX.size() > 0 && byY.size() > 0 && byZ.size() > 0) {
            Vector3d pivot;
            SortData data;
            int mid = (byX.size() / 2);

            SortedMap<Vector3d, SortData> leftX = new TreeMap<>(new ComparatorX());
            SortedMap<Vector3d, SortData> leftY = new TreeMap<>(new ComparatorY());
            SortedMap<Vector3d, SortData> leftZ = new TreeMap<>(new ComparatorZ());

            //split lists in half, set middle element as new KdNode to be returned
            //first list to be split based on level, rest so that they contain points in first list
            //but also keep the ordering by their axis
            switch (level % 3) {
                case 0:
                    pivot = findPivot(byX, mid);
                    data = byX.get(pivot);
                    node = new KdNode(data.facets.get(0), data.indices.get(0), level, parent);
                    for (int i = 1; i < data.facets.size(); i++) {
                        node.addFacet(data.facets.get(i), data.indices.get(i));
                    }
                    splitMaps(pivot, leftX, byX, leftY, byY, leftZ, byZ);
                    break;
                case 1:
                    pivot = findPivot(byY, mid);
                    data = byY.get(pivot);
                    node = new KdNode(data.facets.get(0), data.indices.get(0), level, parent);
                    for (int i = 1; i < data.facets.size(); i++) {
                        node.addFacet(data.facets.get(i), data.indices.get(i));
                    }
                    splitMaps(pivot, leftY, byY, leftX, byX, leftZ, byZ);
                    break;
                case 2:
                    pivot = findPivot(byZ, mid);
                    data = byZ.get(pivot);
                    node = new KdNode(data.facets.get(0), data.indices.get(0), level, parent);
                    for (int i = 1; i < data.facets.size(); i++) {
                        node.addFacet(data.facets.get(i), data.indices.get(i));
                    }
                    splitMaps(pivot, leftZ, byZ, leftY, byY, leftX, byX);
                    break;
                default:
                    return null;
            }

            //removes current middle node from each all lists, in case there were duplicates
            byX.remove(pivot);
            byY.remove(pivot);
            byZ.remove(pivot);
            leftX.remove(pivot);
            leftY.remove(pivot);
            leftZ.remove(pivot);
            
            //System.out.println(data);
            //System.out.println(node);
            //System.out.println(data.indices);
            //System.out.println(data.facets);

            node.setLesser(buildTree(node, leftX, leftY, leftZ, level + 1));
            node.setGreater(buildTree(node, byX, byY, byZ, level + 1));
        }
        
        return node;
    }
    
    private Vector3d findPivot(SortedMap<Vector3d, SortData> map, int index) {
        int i = 0;
        for (Vector3d key: map.keySet()) {
            if (i++ == index) {
                return key;
            }
        }
        return null;
    }

    private void splitMaps(
            Vector3d pivot, 
            SortedMap<Vector3d, SortData> mainEmpty,   SortedMap<Vector3d, SortData> mainOrig, 
            SortedMap<Vector3d, SortData> secondEmpty, SortedMap<Vector3d, SortData> secondOrig, 
            SortedMap<Vector3d, SortData> thirdEmpty,  SortedMap<Vector3d, SortData> thirdOrig) {
        
        /*
         * Split the mainList into two parts
         */
        SortedMap<Vector3d, SortData> sub = mainOrig.subMap(mainOrig.firstKey(), pivot);
        mainEmpty.putAll(sub);
        sub.clear(); // remove from the mainOrig (sub is backed by the mainOrig)
        
        /*
         * Split the second list 
         */
        secondEmpty.putAll(secondOrig);
        secondEmpty.keySet().retainAll(mainEmpty.keySet());
        secondOrig.keySet().removeAll(secondEmpty.keySet());
        
        /*
         * Split the third list 
         */
        thirdEmpty.putAll(thirdOrig);
        thirdEmpty.keySet().retainAll(mainEmpty.keySet());
        thirdOrig.keySet().removeAll(thirdEmpty.keySet());
    }
    
    private boolean firstIsLessThanSecond(Vector3d v1, Vector3d v2, int level){
        switch (level % 3) {
            case 0:
                return v1.x <= v2.x;
            case 1:
                return v1.y <= v2.y;
            case 2:
                return v1.z <= v2.z;
            default:
                break;
        }
        return false;
    }
    

    
    /**
     * Calculates distance between two points
     * (currently searched node and point to which we want to find nearest neighbor)
     * (based on axis)
     *
     */
    private double minDistanceIntersection(Vector3d nodePosition, Vector3d pointPosition, int level){
        switch (level % 3) {
            case 0:
                return Math.abs(nodePosition.x - pointPosition.x);
            case 1:
                return Math.abs(nodePosition.y - pointPosition.y);
            default:
                return Math.abs(nodePosition.z - pointPosition.z);
        }
    }
   
    
    /***********************************************************
    *  EMBEDDED CLASSES
    ************************************************************/   

    /**
     * Helper class used during the kd-tree creation to store data related 
     * to a single 3D location.
     * 
     * @author Radek Oslejsek
     */
    private class SortData {
        public final List<MeshFacet> facets = new ArrayList<>();
        public final List<Integer> indices = new ArrayList<>();
        
        SortData(MeshFacet f, int i) {
            facets.add(f);
            indices.add(i);
        }
    }
    
    /**
     * Comparator prioritizing the X coordinate.
     * @author Radek Oslejsek
     */
    private class ComparatorX implements Comparator<Vector3d> {
        @Override
        public int compare(Vector3d arg0, Vector3d arg1) {
            int diff = Double.compare(arg0.x, arg1.x);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.z, arg1.z);
        }    
    }
    
    /**
     * Comparator prioritizing the X coordinate.
     * @author Radek Oslejsek
     */
    private class ComparatorY implements Comparator<Vector3d> {
        @Override
        public int compare(Vector3d arg0, Vector3d arg1) {
            int diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.x, arg1.x);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.z, arg1.z);
        }    
    }
    
    /**
     * Comparator prioritizing the X coordinate.
     * @author Radek Oslejsek
     */
    private class ComparatorZ implements Comparator<Vector3d> {
        @Override
        public int compare(Vector3d arg0, Vector3d arg1) {
            int diff = Double.compare(arg0.z, arg1.z);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.x, arg1.x);
        }    
    }
}
