package cz.fidentis.analyst.visitors.octree;

import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshTriangle;
import cz.fidentis.analyst.octree.OctNode;
import cz.fidentis.analyst.octree.Octree;
import cz.fidentis.analyst.octree.OctreeVisitor;
import java.util.HashSet;
import java.util.Set;
import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;

/**
 * Octree visitor that casts a ray from an origin point and collects all
 * intersections of the ray with triangles stored in the visited octree(s).
 * <p>
 * Each call to {@link #visitOctree(Octree)} traverses the octree twice: once
 * in the ray direction and once in the opposite direction. Intersections from
 * all visited octrees are accumulated into a single
 * {@link RayIntersectionsData} instance available through
 * {@link #getIntersectionsData()}.
 * </p>
 *
 * @author Enkh-Undral EnkhBayar
 */
public class RayIntersectionVisitor extends OctreeVisitor {

    /** Tolerance used for floating point comparisons. */
    private static final double EPS = 1e-10;

    /**
     * Threshold for the dot product of the vertex normal and the
     * (center of gravity -> origin) vector. If the dot product is lower, the
     * latter vector is used as the ray direction instead of the normal.
     */
    private static final double DOT_CONSTANT = 0.7;

    private final Point3d originPoint;

    /** Private copy of the ray direction; it is inverted in place during traversal. */
    private final Vector3d direction;

    /** {@code true} while {@link #direction} has its original (non-inverted) orientation. */
    private boolean isOriginalDirection = true;

    /** Value of {@link Octree#getMinLen()} of the octree being visited. */
    private double octreeMinLen;

    /** Triangles already processed in the current traversal pass. */
    private final Set<MeshTriangle> calculatedTriangles = new HashSet<>();

    private final RayIntersectionsData data;

    /**
     * Constructor. The ray starts at the position of the given vertex and
     * follows the (normalized) vertex normal. If {@code centerOfGravity} is not
     * {@code null} and the dot product of the normal with the normalized vector
     * (centerOfGravity -> origin) is lower than {@code 0.7}, the latter vector
     * is used as the ray direction instead.
     *
     * @param mainFacet facet holding the origin point. Must not be {@code null}
     * @param index index of the origin point in mainFacet
     * @param centerOfGravity center of gravity of mainFacet, which is average
     * of all positions on mainFacet. Can be {@code null}
     * @throws IllegalArgumentException if {@code mainFacet} is {@code null}
     */
    public RayIntersectionVisitor(MeshFacet mainFacet, int index, Point3d centerOfGravity) {
        if (mainFacet == null) {
            throw new IllegalArgumentException("meshFacet is null");
        }
        originPoint = mainFacet.getVertex(index).getPosition();

        // Work on a copy: normalize() and the later in-place inversion must not
        // modify the normal object obtained from the mesh vertex.
        Vector3d rayDirection = new Vector3d(mainFacet.getVertex(index).getNormal());
        rayDirection.normalize();
        if (centerOfGravity != null) {
            Vector3d movement = new Vector3d(originPoint);
            movement.sub(centerOfGravity);
            movement.normalize();
            if (rayDirection.dot(movement) < DOT_CONSTANT) {
                rayDirection = movement;
            }
        }
        direction = rayDirection;
        data = new RayIntersectionsData(originPoint);
    }

    /**
     * Constructor. The ray starts at the position of the given vertex and
     * follows the (normalized) vertex normal.
     *
     * @param mainFacet facet holding the origin point. Must not be {@code null}
     * @param index index of the origin point in mainFacet
     * @throws IllegalArgumentException if {@code mainFacet} is {@code null}
     */
    public RayIntersectionVisitor(MeshFacet mainFacet, int index) {
        this(mainFacet, index, null);
    }

    /**
     * Constructor for testing purposes. The direction is used as given (it is
     * not normalized here).
     *
     * @param originPoint origin point of the ray. Must not be {@code null}
     * @param direction direction of the ray. Must not be {@code null}
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    RayIntersectionVisitor(Point3d originPoint, Vector3d direction) {
        if (originPoint == null) {
            throw new IllegalArgumentException("originPoint is null");
        }
        if (direction == null) {
            throw new IllegalArgumentException("direction is null");
        }
        this.originPoint = originPoint;
        // Copy so that the in-place inversion never changes the caller's vector.
        this.direction = new Vector3d(direction);
        data = new RayIntersectionsData(originPoint);
    }

    /**
     * Traverses the octree along the ray in both the current and the opposite
     * direction and stores all found intersections. The ray direction is
     * restored before the method returns normally.
     *
     * @param octree octree to be searched. Must not be {@code null}
     * @throws IllegalArgumentException if {@code octree} is {@code null}
     */
    @Override
    public void visitOctree(Octree octree) {
        if (octree == null) {
            throw new IllegalArgumentException("octree is null");
        }

        octreeMinLen = octree.getMinLen();

        // first pass: current direction
        calculatedTriangles.clear();
        rayOctreeTraversal(octree.getRoot(), originPoint);

        // second pass: opposite direction
        invertVector();
        calculatedTriangles.clear();
        rayOctreeTraversal(octree.getRoot(), originPoint);

        // The explicit assignment guarantees that the final inversion below
        // leaves the flag set to true.
        isOriginalDirection = false;
        calculatedTriangles.clear();
        invertVector();
    }

    /**
     * Returns the object into which the intersections found by
     * {@link #visitOctree(Octree)} are accumulated.
     *
     * @return intersections data of this visitor
     */
    public RayIntersectionsData getIntersectionsData() {
        return data;
    }

    /**
     * Inverts the directional vector and toggles the orientation flag.
     */
    void invertVector() {
        direction.scale(-1);
        isOriginalDirection = !isOriginalDirection;
    }

    /**
     * Checks whether or not the given triangle is valid in relation to the
     * directional vector. A triangle that has already been processed in the
     * current pass is never valid.
     *
     * @param triangle triangle to be checked
     * @return {@code true} if triangle is valid, {@code false} if not
     */
    private boolean isTriangleValid(MeshTriangle triangle) {
        if (calculatedTriangles.contains(triangle)) {
            return false;
        }

        double dotProduct = triangle.computeNormal().dot(direction);
        // original orientation: triangle normal must point (within tolerance) along the ray;
        // inverted orientation: triangle normal must point (within tolerance) against the ray
        return isOriginalDirection ? dotProduct >= -EPS : dotProduct <= EPS;
    }

    /**
     * Goes through the octree provided in param node and calculates all the
     * intersections of the ray (originPoint, direction) with the triangles
     * stored in the leaves the ray passes through.
     *
     * @param node node to be traversed. Must not be {@code null}
     * @param p point in this cube (node) and if node is internal node it is
     * also in the next cube. Must not be {@code null}.
     * @return the point in next cube the ray passes through, or {@code null}
     * if the ray starts outside of the node and does not hit it
     * @throws IllegalArgumentException if any argument is {@code null}
     * @throws RuntimeException if the ray exit point from a leaf is not found
     */
    private Point3d rayOctreeTraversal(OctNode node, Point3d p) {
        if (node == null) {
            throw new IllegalArgumentException("node is null");
        }
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }

        if (!node.isLeafNode()) {
            int index = getOctantIndex(p, node);
            // if ray starts outside of cube we need to get inside first
            if (index == -1) {
                p = rayCubeIntersection(node, false);
                if (p == null) {
                    return null;
                }
                p = getPointInNextCube(p, node);
                index = getOctantIndex(p, node);
            }
            // walk through the children along the ray until the ray leaves this node
            while (index != -1) {
                OctNode child = node.getOctant(index);
                p = rayOctreeTraversal(child, p);
                index = getOctantIndex(p, node);
            }
            return p;
        }

        for (MeshTriangle triangle : node.getTriangles()) {
            boolean valid = isTriangleValid(triangle);
            // remember the triangle regardless of validity so it is skipped next time
            calculatedTriangles.add(triangle);
            if (!valid) {
                continue;
            }
            Point3d intersection = triangle.getRayIntersection(originPoint, direction);
            if (intersection != null) {
                data.addIntersection(triangle.getFacet(), intersection);
            }
        }

        Point3d point = rayCubeIntersection(node, true);
        if (point == null) {
            throw new RuntimeException("Didnt find intersection with bounding box");
        }
        return getPointInNextCube(point, node);
    }

    /**
     * Calculates the point in the next cube in the direction of the ray. Every
     * coordinate of {@code p} that lies (within tolerance) on a boundary of the
     * cube is shifted by half of the octree minimal length along the sign of
     * the corresponding direction component.
     *
     * @param p point on the side / edge of the cube Must not be {@code null}.
     * @param cube cube from which to calculate point in next cube is needed
     * Must not be {@code null}.
     * @return the point in next cube. The ray passes through next cube.
     * @throws IllegalArgumentException if any argument is {@code null} or if
     * no coordinate of {@code p} has been moved
     */
    private Point3d getPointInNextCube(Point3d p, OctNode cube) {
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }

        double[] resultCoor = {p.x, p.y, p.z};
        double[] vCoor = {direction.x, direction.y, direction.z};
        Point3d small = cube.getSmallBoundary();
        double[] smallCoor = {small.x, small.y, small.z};
        Point3d large = cube.getLargeBoundary();
        double[] largeCoor = {large.x, large.y, large.z};
        double step = octreeMinLen / 2;
        for (int i = 0; i < 3; i++) {
            int sign;
            if (vCoor[i] < 0) {
                sign = -1;
            } else if (vCoor[i] > 0) {
                sign = 1;
            } else {
                // the ray does not move along this axis
                continue;
            }
            // if point coor is on the side of the cube, move said coor in direction of vector
            boolean onSmallSide = Math.abs(resultCoor[i] - smallCoor[i]) < EPS;
            boolean onLargeSide = Math.abs(resultCoor[i] - largeCoor[i]) < EPS;
            if (onSmallSide || onLargeSide) {
                resultCoor[i] += sign * step;
            }
        }
        Point3d resultPoint = new Point3d(resultCoor[0], resultCoor[1], resultCoor[2]);
        // Compare by value: a reference comparison with a freshly created
        // object would never detect a point that has not been moved.
        if (resultPoint.equals(p)) {
            throw new IllegalArgumentException("p is not on the edges of the cube");
        }
        return resultPoint;
    }

    /**
     * Calculates the intersection between ray and the cube sides using
     * per-axis entry/exit parameters of the ray.
     *
     * @param cube cube with the sides. Must not be {@code null}.
     * @param gotInsideOnce boolean parameter that tells us if ray got inside
     * the most outer cube. If it did we care about the second point in the
     * direction of the ray instead of the first one. As the first one will be
     * negative and therefore not in direction of the ray.
     * @return returns the intersection in the direction of the ray, or
     * {@code null} if the ray does not hit the cube
     * @throws IllegalArgumentException if {@code cube} is {@code null} or if
     * the origin point lies in the cube while {@code gotInsideOnce} is
     * {@code false}
     */
    Point3d rayCubeIntersection(OctNode cube, boolean gotInsideOnce) {
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }
        if (!gotInsideOnce && isPointInCube(originPoint, cube)) {
            throw new IllegalArgumentException("originPoint is in cube with gotInsideOnce set to false");
        }

        Point3d smallBoundary = cube.getSmallBoundary();
        Point3d largeBoundary = cube.getLargeBoundary();
        double[] origin = {originPoint.x, originPoint.y, originPoint.z};
        double[] dir = {direction.x, direction.y, direction.z};
        double[] low = {smallBoundary.x, smallBoundary.y, smallBoundary.z};
        double[] high = {largeBoundary.x, largeBoundary.y, largeBoundary.z};

        double tEnter = Double.NEGATIVE_INFINITY;
        double tExit = Double.POSITIVE_INFINITY;
        for (int axis = 0; axis < 3; axis++) {
            if (dir[axis] == 0) {
                // Ray is parallel to this pair of planes: it can hit the cube
                // only if the origin already lies between them.
                if (origin[axis] < low[axis] || origin[axis] > high[axis]) {
                    return null;
                }
                continue;
            }
            double tLow = (low[axis] - origin[axis]) / dir[axis];
            double tHigh = (high[axis] - origin[axis]) / dir[axis];
            tEnter = Math.max(tEnter, Math.min(tLow, tHigh));
            tExit = Math.min(tExit, Math.max(tLow, tHigh));
        }
        // no overlap of the per-axis intervals, or the cube lies behind the origin
        if (tEnter > tExit || (tEnter < 0 && tExit < 0)) {
            return null;
        }

        double t = gotInsideOnce ? tExit : tEnter;
        return new Point3d(
                originPoint.x + direction.x * t,
                originPoint.y + direction.y * t,
                originPoint.z + direction.z * t);
    }

    /**
     * Checks if p is inside the cube. Points on the boundary are considered
     * to be inside.
     *
     * @param p point to be checked. Must not be {@code null}.
     * @param cube cube with the bounding box. Must not be {@code null}.
     * @return {@code true} if p is inside the cube, {@code false} otherwise
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    static boolean isPointInCube(Point3d p, OctNode cube) {
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }
        Point3d smallBoundary = cube.getSmallBoundary();
        Point3d largeBoundary = cube.getLargeBoundary();
        if (p.x < smallBoundary.x || p.y < smallBoundary.y || p.z < smallBoundary.z) {
            return false;
        }
        if (p.x > largeBoundary.x || p.y > largeBoundary.y) {
            return false;
        }
        return p.z <= largeBoundary.z;
    }

    /**
     * Calculates the index of the child which can house the point provided. If
     * the p is in any of dividing planes of the cube, lower index will be
     * given. Even if cube is leaf node and therefore has no children, this
     * function will return an index.
     * <p>
     * The index is composed as {@code 4 * (x above middle) + 2 * (y above
     * middle) + 1 * (z above middle)}.
     * </p>
     *
     * @param p point, must not be {@code null}
     * @param cube cube, must not be {@code null}
     * @return index (0 - 7) of child in cube. -1 if p is not inside the cube
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    protected static int getOctantIndex(Point3d p, OctNode cube) {
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }
        if (!isPointInCube(p, cube)) {
            return -1;
        }
        Point3d smallBoundary = cube.getSmallBoundary();
        Point3d largeBoundary = cube.getLargeBoundary();
        double middleX = (smallBoundary.x + largeBoundary.x) / 2;
        double middleY = (smallBoundary.y + largeBoundary.y) / 2;
        double middleZ = (smallBoundary.z + largeBoundary.z) / 2;

        int octantIndex = 0;
        if (p.x > middleX) {
            octantIndex += 4;
        }
        if (p.y > middleY) {
            octantIndex += 2;
        }
        if (p.z > middleZ) {
            octantIndex += 1;
        }
        return octantIndex;
    }
}
