package cz.fidentis.analyst.visitors.kdtree;

import cz.fidentis.analyst.kdtree.KdNode;
import cz.fidentis.analyst.kdtree.KdTree;
import cz.fidentis.analyst.kdtree.KdTreeVisitor;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Set;
import javax.vecmath.Point3d;

/**
 * This visitor finds all nodes from the k-d tree that are the closest to the given 3D location.
 * <p>
 * The visitor accumulates results over all inspected k-d trees: it keeps the minimal
 * distance found so far and all nodes lying at exactly this distance.
 * </p>
 * <p>
 * This visitor is thread-safe, i.e., a single instance of the visitor can be used 
 * to inspect multiple k-d trees simultaneously. The results are meant to be read
 * after all visits have finished.
 * </p>
 * 
 * @author Daniel Schramm
 */
public class KdTreeClosestNode extends KdTreeVisitor {
    
    private final Point3d point3d;

    /** Smallest distance found so far; guarded by {@code this}. */
    private double distance = Double.POSITIVE_INFINITY;

    /** All nodes found so far whose distance equals {@link #distance}; guarded by {@code this}. */
    private final Set<KdNode> closest = new HashSet<>();
    
    /**
     * Constructor.
     * 
     * @param point The 3D location for which the closest nodes are to be computed.
     * @throws IllegalArgumentException if {@code point} is {@code null}
     */
    public KdTreeClosestNode(Point3d point) {
        if (point == null) {
            throw new IllegalArgumentException("point");
        }
        this.point3d = point;
    }

    /**
     * Returns the set of the closest nodes of the visited k-d trees.
     * <p>
     * The returned set is a read-only live view of the internal state. Read it only
     * after all visits have finished; iterating it while another thread is still
     * visiting a tree is not safe.
     * </p>
     * 
     * @return closest nodes of the visited k-d trees (empty if no node has been found)
     */
    public Set<KdNode> getClosestNodes() {
        return Collections.unmodifiableSet(closest);
    }
    
    /**
     * Returns any of the closest nodes, or {@code null} if no node has been found.
     * 
     * @return any closest node or {@code null}
     */
    public synchronized KdNode getAnyClosestNode() {
        // Same element the former "copy to ArrayList and take index 0" returned,
        // without copying the whole set.
        return closest.isEmpty() ? null : closest.iterator().next();
    }
    
    /**
     * Returns the minimal distance found so far.
     * 
     * @return the minimal distance, or {@link Double#POSITIVE_INFINITY} if no node
     *         has been inspected yet
     */
    public synchronized double getDistance() {
        return distance;
    }
    
    /**
     * Returns the 3D point to which the closest nodes of inspected k-d trees are searched.
     * 
     * @return the reference point (the same instance that was passed to the constructor)
     */
    public Point3d getReferencePoint() {
        return point3d;
    }

    @Override
    public void visitKdTree(KdTree kdTree) {
        final KdNode root = kdTree.getRoot();
        if (root == null) {
            return;
        }
        
        /*
         * Phase 1: reduce the search space by descending from the root along the
         * side on which the reference point lies. This quickly yields a good upper
         * bound on the distance so that phase 2 can prune more subtrees.
         */
        KdNode node = root;
        while (node != null) {
            final Point3d nodePos = node.getLocation();
            recordCandidate(node, nodePos.distance(point3d));
            node = firstIsLessThanSecond(point3d, nodePos, node.getDepth())
                    ? node.getLesser()
                    : node.getGreater();
        }

        /*
         * Phase 2: search for nodes that could potentially be closer than
         * the closest distance already found, or at the same distance.
         */
        final Queue<KdNode> queue = new ArrayDeque<>();
        queue.add(root);

        while (!queue.isEmpty()) {
            final KdNode current = queue.poll();
            final Point3d nodePos = current.getLocation();
            final int depth = current.getDepth();

            // The bound is obtained under the lock because other threads visiting
            // other trees may lower it concurrently.
            final double bound = recordCandidate(current, nodePos.distance(point3d));
            final double distOnAxis = minDistanceIntersection(nodePos, point3d, depth);

            if (distOnAxis > bound) {
                // The splitting plane is farther than the best distance, so only
                // the side containing the reference point is worth exploring.
                enqueueIfPresent(queue, firstIsLessThanSecond(point3d, nodePos, depth)
                        ? current.getLesser()
                        : current.getGreater());
            } else {
                enqueueIfPresent(queue, current.getLesser());
                enqueueIfPresent(queue, current.getGreater());
            }
        }
    }

    /**
     * Updates the shared result with a candidate node and returns the current best distance.
     * A strictly closer candidate replaces all previous ones; an equally distant one is added.
     * 
     * @param node candidate node
     * @param dist distance of the candidate from the reference point
     * @return the minimal distance after the update
     */
    private synchronized double recordCandidate(KdNode node, double dist) {
        if (dist < distance) {
            distance = dist;
            closest.clear();
            closest.add(node);
        } else if (dist == distance) {
            closest.add(node);
        }
        return distance;
    }

    /**
     * Adds {@code node} to {@code queue} unless it is {@code null}.
     */
    private static void enqueueIfPresent(Queue<KdNode> queue, KdNode node) {
        if (node != null) {
            queue.add(node);
        }
    }

    /**
     * Decides on which side of a node's splitting plane a point lies.
     * 
     * @param v1 the point being classified
     * @param v2 the location of the node defining the splitting plane
     * @param level depth of the node; {@code level % 3} selects the axis (0 = x, 1 = y, 2 = z)
     * @return {@code true} if the coordinate of {@code v1} on the selected axis is less than
     *         or equal to that of {@code v2} (the lesser side), {@code false} otherwise;
     *         always {@code false} for a negative {@code level}
     */
    protected boolean firstIsLessThanSecond(Point3d v1, Point3d v2, int level){
        switch (level % 3) {
            case 0:
                return v1.x <= v2.x;
            case 1:
                return v1.y <= v2.y;
            case 2:
                return v1.z <= v2.z;
            default:
                return false;
        }
    }
    
    /**
     * Calculates the distance between the currently searched node and the point whose
     * nearest neighbor is searched, measured only along the splitting axis of the node.
     * The visitor uses it as a lower bound for pruning the far side of the splitting plane.
     * 
     * @param nodePosition location of the currently searched node
     * @param pointPosition the point whose nearest neighbor is searched
     * @param level depth of the node; {@code level % 3} selects the axis (0 = x, 1 = y,
     *              otherwise z, which includes negative levels)
     * @return absolute difference of the two coordinates on the selected axis
     */
    protected double minDistanceIntersection(Point3d nodePosition, Point3d pointPosition, int level){
        switch (level % 3) {
            case 0:
                return Math.abs(nodePosition.x - pointPosition.x);
            case 1:
                return Math.abs(nodePosition.y - pointPosition.y);
            default:
                return Math.abs(nodePosition.z - pointPosition.z);
        }
    }
    
}
