package cz.fidentis.analyst.kdtree;

import com.google.common.eventbus.EventBus;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshFacetImpl;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import javax.vecmath.Point3d;


/**
 * KD-tree for storing vertices ({@code MeshPoint}s) of triangular meshes ({@code MeshFacet}s).
 * Multiple mesh facets can by stored in a single kd-tree. In this case, 
 * vertices that are shared across multiple facets (have the same 3D location) 
 * are shared in the same node of the kd-tree.
 * <p>
 * This class implements the publish-subscribe notifications to changes.
 * </p>
 * <p>
 * Events fired by the class:
 * <ul>
 * <li>None because no modification method is available so far.</li>
 * </ul>
 * </p>
 *
 * @author Maria Kocurekova
 * @author Radek Oslejsek
 */
public class KdTree implements Serializable {
    
    private KdNode root;
    
    private final transient EventBus eventBus = new EventBus();
    
    /**
     * Constructor.
     *
     * @param points A set of individual mesh points. 
     * If no mesh points are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     */
    public KdTree(Set<MeshPoint> points) {
        if (points == null) {
            this.root = null;
            return;
        }
        MeshFacet newFacet = new MeshFacetImpl();
        for (MeshPoint point : points) {
            newFacet.addVertex(point);
        }
        buildTree(new LinkedList<>(Collections.singleton(newFacet)));
    }

    /**
     * Constructor. If no mesh points (vertices) are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     *
     * @param facet Mesh facet
     */
    public KdTree(MeshFacet facet) {
        this(new LinkedList<>(Collections.singleton(facet)));
    }

    /**
     * Constructor. If no mesh points (vertices) are provided, then an empty 
     * KD-tree is constructed (with the root node set to null).
     * If multiple mesh facets share the same vertex, then they are stored 
     * efficiently in the same node of the KD-tree.
     *
     * @param facets The list of mesh facets to be stored. Facets can share vertices.
     */
    public KdTree(List<MeshFacet> facets) {
        if(facets == null ||  facets.isEmpty() || facets.get(0).getVertices().isEmpty() ){
            this.root = null;
            return;
        }
        buildTree(facets);        
    }
    
    /**
     * Registers listeners (objects concerned in the kd-tree changes) to receive events.
     * If listener is {@code null}, no exception is thrown and no action is taken.
     * 
     * @param listener Listener concerned in the kd-tree changes.
     */
    public void registerListener(KdTreeListener listener) {
        eventBus.register(listener);
    }
    
    /**
     * Unregisters listeners from receiving events.
     * 
     * @param listener Registered listener
     */
    public void unregisterListener(KdTreeListener listener) {
        eventBus.unregister(listener);
    }

    /**
     * Tree traversal - go to the "root" of the tree.
     * 
     * @return root node of the tree
     */
    public KdNode getRoot() {
        return root;
    }

    @Override
    public String toString() {
        String ret = "";
        Queue<KdNode> queue = new LinkedList<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            KdNode node = queue.poll();
            if (node == null) {
                continue;
            }
            queue.add(node.getLesser());
            queue.add(node.getGreater());
            ret += node.toString();
        }
        
        return ret;
    }
    
    /**
     * Return number of nodes in the k-d tree.
     * @return number of nodes in the k-d tree
     */
    public int getNumNodes() {
        int ret = 0;
        Queue<KdNode> queue = new LinkedList<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            KdNode node = queue.poll();
            if (node == null) {
                continue;
            }
            queue.add(node.getLesser());
            queue.add(node.getGreater());
            ret++;
        }
        
        return ret;
    }
    
    /**
     * Return the length of the longest path.
     * @return Return the length of the longest path.
     */
    public int getDepth() {
        int depth = 0;
        Queue<KdNode> queue = new LinkedList<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            KdNode node = queue.poll();
            if (node == null) {
                continue;
            }
            queue.add(node.getLesser());
            queue.add(node.getGreater());
            if (node.getDepth() > depth) {
                depth = node.getDepth();
            }
        }
        
        return depth;
    }
    
    /**
     * Visits this tree.
     * 
     * @param visitor Visitor
     */
    public void accept(KdTreeVisitor visitor) {
        visitor.visitKdTree(this);
    }
    

    /***********************************************************
     *  PRIVATE METHODS
     ***********************************************************/  
    
    private void buildTree(List<MeshFacet> facets) {
        SortedMap<Point3d, AggregatedVertex> vertices = new TreeMap<>(new ComparatorX());
        
        /*
         * Sort all vertices according to the X coordinate, aggregate vertices
         * with the same 3D location.
         */
        for (MeshFacet facet: facets) {
            int index = 0;
            for (MeshPoint p: facet.getVertices()) {
                Point3d k = p.getPosition();
                if (vertices.containsKey(k)) {
                    vertices.get(k).facets.add(facet);
                    vertices.get(k).indices.add(index);
                } else {
                    vertices.put(k, new AggregatedVertex(facet, index));
                }
                index++;
            }
        }
        
        root = buildTree(null, vertices, 0);
    }
    
    /**
     * Builds kd-tree.
     *
     * @param parent Parent node
     * @param vertices List of aggregates sorted vertices
     * @param level Tree depth that affects the splitting direction (x, y, or z)
     * @return new node of the kd-tree or null
     */
    private KdNode buildTree(KdNode parent, SortedMap<Point3d, AggregatedVertex> vertices, int level) {
        
        if (vertices.isEmpty()) {
            return null;
        }
        
        Point3d pivot = findPivot(vertices, level);
        AggregatedVertex data = vertices.get(pivot);
        
        KdNode node = new KdNode(data.facets, data.indices, level, parent);
            
        SortedMap<Point3d, AggregatedVertex> left = null;
        SortedMap<Point3d, AggregatedVertex> right = null;
            
        switch ((level + 1) % 3) {
            case 0:
                left = new TreeMap<>(new ComparatorX());
                right = new TreeMap<>(new ComparatorX());
                break;
            case 1:
                left = new TreeMap<>(new ComparatorY());
                right = new TreeMap<>(new ComparatorY());
                break;
            case 2:
                left = new TreeMap<>(new ComparatorZ());
                right = new TreeMap<>(new ComparatorZ());
                break;
            default:
                return null;
        }

        left.putAll(vertices.subMap(vertices.firstKey(), pivot));
        vertices.keySet().removeAll(left.keySet());
        vertices.remove(pivot);
        right.putAll(vertices);

        node.setLesser(buildTree(node, left, level + 1));
        node.setGreater(buildTree(node, right, level + 1));
        
        return node;
    }
    
    /**
     * Finds and returns the middle key or the first key with the same level-th coordinate.
     * 
     * @param map Map to be searched
     * @param level Tree depth that affects the choice of the coordinate (x, y, or z)
     * @return middle key or the first key with the same level-th coordinate
     */
    private Point3d findPivot(SortedMap<Point3d, AggregatedVertex> map, int level) {
        int mid = (map.size() / 2);
        int i = 0;
        
        Point3d fst = map.isEmpty() ? null : map.firstKey(); // First point with the same level-th coordinate as the 'key' point
        for (Point3d key: map.keySet()) {
            if (i++ == mid) {
                if (singleCoordinateEquals(key, fst, level)) {
                    return fst;
                }
                return key;
            }
            if (!singleCoordinateEquals(key, fst, level)) {
                fst = key;
            }
        }
        return null;
    }
    
    private boolean singleCoordinateEquals(Point3d v1, Point3d v2, int level) {
        switch (level % 3) {
            case 0:
                return v1.x == v2.x;
            case 1:
                return v1.y == v2.y;
            default:
                return v1.z == v2.z;
        }
    }
   
    
    /***********************************************************
    *  EMBEDDED CLASSES
    ************************************************************/   

    /**
     * Helper class used during the kd-tree creation to store mesh vertices
     * with the same 3D location.
     * 
     * @author Radek Oslejsek
     */
    private class AggregatedVertex {
        public final List<MeshFacet> facets = new ArrayList<>();
        public final List<Integer> indices = new ArrayList<>();
        
        AggregatedVertex(MeshFacet f, int i) {
            facets.add(f);
            indices.add(i);
        }
    }
    
    /**
     * Comparator prioritizing the X coordinate.
     * @author Radek Oslejsek
     */
    private class ComparatorX implements Comparator<Point3d> {
        @Override
        public int compare(Point3d arg0, Point3d arg1) {
            int diff = Double.compare(arg0.x, arg1.x);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.z, arg1.z);
        }    
    }
    
    /**
     * Comparator prioritizing the X coordinate.
     * @author Radek Oslejsek
     */
    private class ComparatorY implements Comparator<Point3d> {
        @Override
        public int compare(Point3d arg0, Point3d arg1) {
            int diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.x, arg1.x);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.z, arg1.z);
        }    
    }
    
    /**
     * Comparator prioritizing the X coordinate.
     * @author Radek Oslejsek
     */
    private class ComparatorZ implements Comparator<Point3d> {
        @Override
        public int compare(Point3d arg0, Point3d arg1) {
            int diff = Double.compare(arg0.z, arg1.z);
            if (diff != 0) { 
                return diff; 
            }
            diff = Double.compare(arg0.y, arg1.y);
            if (diff != 0) { 
                return diff; 
            }
            return Double.compare(arg0.x, arg1.x);
        }    
    }
}
