package cz.fidentis.analyst.visitors.octree;

import cz.fidentis.analyst.mesh.MeshVisitor;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.mesh.core.MeshTriangle;
import cz.fidentis.analyst.octree.OctNode;
import cz.fidentis.analyst.octree.Octree;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;

/**
 * Visitor that casts rays from every vertex of the main facets and records
 * where those rays intersect the visited mesh facets.
 * <p>
 * For every visited facet an {@link Octree} is built. Then, for every vertex
 * of every main facet, two rays are cast along the vertex normal: one in the
 * direction of the normal and one in the opposite direction. The octree is
 * traversed along each ray, and every triangle adjacent to the vertices stored
 * in the traversed leaf nodes is tested for an intersection. For each
 * (main facet, vertex, visited facet) triple only the intersection closest to
 * the vertex is kept.
 * </p>
 * 
 * @author Enkh-Undral EnkhBayar
 */
public class OctreeArrayIntersectionVisitor extends MeshVisitor {
    /** 
     * Main mesh facets whose vertices are used as ray origins.
     */
    private final Set<MeshFacet> mainFacets;
    
    /** 
     * Triangles already tested during the current ray traversal.
     * A triangle can be reached from several stored vertices / leaf nodes, so
     * this cache prevents testing it repeatedly. Cleared before every ray.
     */
    private final Set<MeshTriangle> calculatedTriangles = new HashSet<>();
    
    /**
     * Calculated intersections, in the format
     * {@code <main facet, <vertex index, <intersected facet, intersection>>>}:
     * <ul>
     * <li>main facet - facet from {@code mainFacets} holding the ray origin,</li>
     * <li>vertex index - index of the ray origin in the main facet; the origin
     *     is {@code mainFacet.getVertex(index)},</li>
     * <li>intersected facet - visited facet hit by the ray,</li>
     * <li>intersection - the hit closest to the ray origin.</li>
     * </ul>
     */
    private final Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> intersections = new HashMap<>();
    
    /**
     * Constructor.
     *
     * @param mainFacet the main mesh facet from which to calculate the distances.
     *                  Must not be {@code null}
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public OctreeArrayIntersectionVisitor(MeshFacet mainFacet) {
        this(toSingletonSet(mainFacet));
    }
    
    /**
     * Constructor.
     *
     * @param mainFacets the main mesh facets from which to calculate the distances. 
     *                   Must not be {@code null}, empty, or contain {@code null}
     * @throws IllegalArgumentException if some parameter is wrong
     */
    public OctreeArrayIntersectionVisitor(Set<MeshFacet> mainFacets) {
        if (mainFacets == null || mainFacets.isEmpty()) {
            throw new IllegalArgumentException("mainFacets");
        }
        // Iterate instead of contains(null): null-hostile sets (e.g. Set.of)
        // throw NullPointerException from contains(null).
        for (MeshFacet facet : mainFacets) {
            if (facet == null) {
                throw new IllegalArgumentException("mainFacets");
            }
        }
        this.mainFacets = mainFacets;
    }
    
    /**
     * Validates a single main facet before delegating to the set constructor.
     * This must happen before {@code this(...)}, because a check placed after
     * the delegation can never be reached.
     *
     * @param mainFacet facet to wrap, must not be {@code null}
     * @return a mutable set holding only {@code mainFacet}
     * @throws IllegalArgumentException if {@code mainFacet} is {@code null}
     */
    private static Set<MeshFacet> toSingletonSet(MeshFacet mainFacet) {
        if (mainFacet == null) {
            throw new IllegalArgumentException("mainFacet");
        }
        return new HashSet<>(Collections.singleton(mainFacet));
    }
    
    /**
     * Returns the calculated intersections as read-only views.
     * The format is described on the {@code intersections} field.
     *
     * @return intersections from points in mainFacets to other mesh facets
     */
    public Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> getIntersections() {
        Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> unmodifiableMap = new HashMap<>();
        for (var entry : intersections.entrySet()) {
            Map<Integer, Map<MeshFacet, Point3d>> unmodifiableInnerMap = new HashMap<>();
            for (var innerEntry : entry.getValue().entrySet()) {
                unmodifiableInnerMap.put(innerEntry.getKey(), Collections.unmodifiableMap(innerEntry.getValue()));
            }
            unmodifiableMap.put(entry.getKey(), Collections.unmodifiableMap(unmodifiableInnerMap));
        }
        return Collections.unmodifiableMap(unmodifiableMap);
    }
    
    /**
     * Builds an octree over {@code facet} and casts rays from every vertex of
     * every main facet into it.
     *
     * @param facet visited facet, must not be {@code null}
     * @throws IllegalArgumentException if {@code facet} is {@code null} or a
     *         vertex of a main facet has no normal
     */
    @Override
    public void visitMeshFacet(MeshFacet facet) {
        if (facet == null) {
            throw new IllegalArgumentException("facet is null");
        }
        Octree octree = new Octree(facet);
        OctNode root = octree.getRoot();
        double minLen = octree.getMinLen();
        for (MeshFacet mainFacet : mainFacets) {
            int index = 0;
            for (MeshPoint meshPoint : mainFacet.getVertices()) {
                castRaysFromVertex(mainFacet, index, meshPoint, root, minLen);
                index++;
            }
        }
    }
    
    /**
     * Casts two rays from one vertex of a main facet into the octree: one
     * along the vertex normal and one in the opposite direction.
     * Vertices whose normal is (0, 0, 0) are skipped, because no ray direction
     * can be derived from them.
     *
     * @param mainFacet main facet holding the vertex
     * @param index index of the vertex in {@code mainFacet}
     * @param meshPoint the vertex itself
     * @param root root of the octree to traverse
     * @param minLen smallest side length of any node in the octree
     * @throws IllegalArgumentException if the vertex has no normal
     */
    private void castRaysFromVertex(MeshFacet mainFacet, int index, MeshPoint meshPoint, OctNode root, double minLen) {
        Vector3d normal = meshPoint.getNormal();
        if (normal == null) {
            throw new IllegalArgumentException("normal of vertex " + index + " is null");
        }
        if (isZeroVector(normal)) {
            return;
        }
        // Work on a copy: getNormal() may hand out the vertex's own normal
        // object, and negating that in place would corrupt the main facet.
        Vector3d direction = new Vector3d(normal);
        Point3d origin = meshPoint.getPosition();
        
        calculatedTriangles.clear();
        calculateDistanceInPoint(mainFacet, index, root, origin, direction, minLen);
        
        direction.negate();
        calculatedTriangles.clear();
        calculateDistanceInPoint(mainFacet, index, root, origin, direction, minLen);
        calculatedTriangles.clear();
    }

    /**
     * Updates the intersections. If there already is an intersection with 
     * the facet, the one closer to the starting point is kept and the other
     * is discarded.
     * 
     * @param mainFacet mainFacet holding the starting point.
     *                  Has to be in {@code this.mainFacets}.
     *                  Must not be {@code null}.
     * @param startPointIndex index of the point in mainFacet from which the 
     *                        intersection was calculated.
     *                        mainFacet has to hold a point with startPointIndex.
     * @param facet facet in the node with the intersection.
     *              Must not be {@code null}.
     * @param intersection intersection of facet and the ray from the start.
     *                     Must not be {@code null}.
     */
    private void updateIntersections(MeshFacet mainFacet, int startPointIndex, MeshFacet facet, Point3d intersection) {
        if (mainFacet == null) {
            throw new IllegalArgumentException("start is null");
        }
        if (facet == null) {
            throw new IllegalArgumentException("facet is null");
        }
        if (intersection == null) {
            throw new IllegalArgumentException("intersection is null");
        }
        Map<MeshFacet, Point3d> intersectionsMap = intersections
                .computeIfAbsent(mainFacet, key -> new HashMap<>())
                .computeIfAbsent(startPointIndex, key -> new HashMap<>());
        Point3d oldIntersection = intersectionsMap.get(facet);
        if (oldIntersection != null) {
            Point3d startingPoint = mainFacet.getVertex(startPointIndex).getPosition();
            if (startingPoint.distance(oldIntersection) <= startingPoint.distance(intersection)) {
                return;
            }
        }
        intersectionsMap.put(facet, intersection);
    }
    
    /**
     * Traverses the octree rooted in {@code node} along the ray and records
     * the intersections with triangles stored in the visited leaves.
     * 
     * @param mainFacet mainFacet holding the starting point.
     *                  Has to be in {@code this.mainFacets}.
     *                  Must not be {@code null}.
     * @param origPointIndex index of the point in mainFacet from which the 
     *                       ray is cast.
     *                       mainFacet has to hold a point with this index.
     * @param node node to traverse.
     *             Must not be {@code null}.
     * @param p point in this cube (node); if the node is an internal node,
     *          it is also in the next cube to descend into. 
     *          Must not be {@code null}.
     * @param v vector from the starting point which defines the ray.
     *          Must not be {@code null} nor (0, 0, 0).
     * @param minLen the smallest side length of any cube (node) in node. 
     *               Must be positive.
     * @return the point in the next cube through which the ray passes, or
     *         {@code p} if the ray misses {@code node} entirely
     */
    private Point3d calculateDistanceInPoint(MeshFacet mainFacet, int origPointIndex, OctNode node, Point3d p, Vector3d v, double minLen) {
        if (mainFacet == null) {
            throw new IllegalArgumentException("meshFacet is null");
        }
        if (node == null) {
            throw new IllegalArgumentException("node is null");
        }
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (v == null) {
            throw new IllegalArgumentException("v is null");
        }
        if (isZeroVector(v)) {
            throw new IllegalArgumentException("v is (0, 0, 0)");
        }
        if (!(minLen > 0)) {
            throw new IllegalArgumentException("minLen must be positive");
        }
        Point3d origPoint = mainFacet.getVertex(origPointIndex).getPosition();

        if (node.isLeafNode()) {
            intersectLeafTriangles(mainFacet, origPointIndex, origPoint, node, v);
            Point3d exit = rayCubeIntersection(origPoint, v, node, true);
            if (exit == null) {
                throw new IllegalStateException("Didnt find intersection with bounding box");
            }
            return getPointInNextCube(exit, v, node, minLen);
        }

        Point3d current = p;
        int index = getOctantIndex(current, node);
        if (index == -1) {
            // p lies outside this node: jump to where the ray enters it.
            Point3d entry = rayCubeIntersection(origPoint, v, node, false);
            if (entry == null) {
                // The ray misses this node entirely, so there is nothing to intersect.
                return p;
            }
            current = getPointInNextCube(entry, v, node, minLen);
            index = getOctantIndex(current, node);
        }
        while (index != -1) {
            current = calculateDistanceInPoint(mainFacet, origPointIndex, node.getOctant(index), current, v, minLen);
            index = getOctantIndex(current, node);
        }
        return current;
    }
    
    /**
     * Tests the ray against every not-yet-tested triangle adjacent to the
     * vertices stored in a leaf node and records the hits.
     *
     * @param mainFacet main facet holding the ray origin
     * @param origPointIndex index of the ray origin in {@code mainFacet}
     * @param origPoint the ray origin
     * @param leaf leaf node whose stored vertices are examined
     * @param v direction of the ray
     */
    private void intersectLeafTriangles(MeshFacet mainFacet, int origPointIndex, Point3d origPoint, OctNode leaf, Vector3d v) {
        for (Map.Entry<MeshFacet, Integer> entry : leaf.getFacets().entrySet()) {
            MeshFacet facet = entry.getKey();
            for (MeshTriangle triangle : facet.getAdjacentTriangles(entry.getValue())) {
                if (!calculatedTriangles.add(triangle)) {
                    continue; // already tested for this ray
                }
                Point3d intersection = triangle.getRayIntersection(origPoint, v);
                if (intersection != null) {
                    updateIntersections(mainFacet, origPointIndex, facet, intersection);
                }
            }
        }
    }
    
    /**
     * Calculates a point in the next cube by moving {@code p} by half of
     * {@code minLen}, in the direction of the ray, along every axis on which
     * {@code p} lies exactly on a face of the cube.
     * 
     * @param p point on the side / edge of the cube.
     *          Must not be {@code null}.
     * @param v vector of the ray.
     *          Must not be {@code null} nor (0, 0, 0).
     * @param cube cube for which the point in the next cube is needed.
     *             Must not be {@code null}.
     * @param minLen the smallest side length of any cube (node) in the octree. 
     *               Must be positive.
     * @return the point in the next cube through which the ray passes
     * @throws IllegalArgumentException if some parameter is wrong, or if
     *         {@code p} is not on a face of the cube crossed by the ray
     */
    private Point3d getPointInNextCube(Point3d p, Vector3d v, OctNode cube, double minLen) {
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (v == null) {
            throw new IllegalArgumentException("v is null");
        }
        if (isZeroVector(v)) {
            throw new IllegalArgumentException("v is (0, 0, 0)");
        }
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }
        if (!(minLen > 0)) {
            throw new IllegalArgumentException("minLen must be positive");
        }
        double[] coords = {p.x, p.y, p.z};
        double[] direction = {v.x, v.y, v.z};
        Point3d small = cube.getSmallBoundary();
        Point3d large = cube.getLargeBoundary();
        double[] lower = {small.x, small.y, small.z};
        double[] upper = {large.x, large.y, large.z};
        double step = minLen / 2;
        for (int axis = 0; axis < 3; axis++) {
            if (direction[axis] == 0) {
                continue; // the ray never leaves the cube along this axis
            }
            if (coords[axis] == lower[axis] || coords[axis] == upper[axis]) {
                coords[axis] += Math.signum(direction[axis]) * step;
            }
        }
        Point3d resultPoint = new Point3d(coords[0], coords[1], coords[2]);
        if (resultPoint.equals(p)) {
            // No progress would make the caller revisit the same cube forever.
            throw new IllegalArgumentException("p is not on the edges of the cube");
        }
        return resultPoint;
    }
    
    /**
     * Calculates the intersection between the ray (line) and the cube sides
     * using the slab method.
     * <p>
     * The coordinate(s) that determine the returned parameter are snapped
     * exactly onto the corresponding face, and the remaining coordinates are
     * clamped into the cube. This way floating-point rounding cannot leave the
     * point marginally off the face, where {@code getPointInNextCube} would
     * fail to advance it.
     * </p>
     * 
     * @param p starting point of the ray.
     *          Must not be {@code null}.
     * @param v vector of the ray.
     *          Must not be {@code null} nor (0, 0, 0).
     * @param cube cube with the sides.
     *             Must not be {@code null}.
     * @param gotInsideOnce if {@code true}, the exit point (second in the
     *                      direction of the ray) is returned; otherwise the
     *                      entry point (first one) is returned
     * @return the intersection in the direction of the ray, or {@code null}
     *         if the line does not intersect the cube
     */
    private Point3d rayCubeIntersection(Point3d p, Vector3d v, OctNode cube, boolean gotInsideOnce) {
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (v == null) {
            throw new IllegalArgumentException("v is null");
        }
        if (isZeroVector(v)) {
            throw new IllegalArgumentException("v is (0, 0, 0)");
        }
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }
        Point3d small = cube.getSmallBoundary();
        Point3d large = cube.getLargeBoundary();
        double[] origin = {p.x, p.y, p.z};
        double[] direction = {v.x, v.y, v.z};
        double[] lower = {small.x, small.y, small.z};
        double[] upper = {large.x, large.y, large.z};

        double[] tNear = new double[3];
        double[] tFar = new double[3];
        double tEnter = -Double.MAX_VALUE;
        double tExit = Double.MAX_VALUE;
        for (int axis = 0; axis < 3; axis++) {
            if (direction[axis] == 0) {
                // Parallel to this slab: the line is inside it everywhere or nowhere.
                if (origin[axis] < lower[axis] || origin[axis] > upper[axis]) {
                    return null;
                }
                continue;
            }
            double t0 = (lower[axis] - origin[axis]) / direction[axis];
            double t1 = (upper[axis] - origin[axis]) / direction[axis];
            tNear[axis] = Math.min(t0, t1);
            tFar[axis] = Math.max(t0, t1);
            tEnter = Math.max(tEnter, tNear[axis]);
            tExit = Math.min(tExit, tFar[axis]);
        }
        if (tEnter > tExit) {
            return null;
        }
        double t = gotInsideOnce ? tExit : tEnter;
        double[] result = new double[3];
        for (int axis = 0; axis < 3; axis++) {
            double coordinate = origin[axis] + direction[axis] * t;
            if (direction[axis] != 0) {
                boolean positive = direction[axis] > 0;
                if (gotInsideOnce && tFar[axis] == t) {
                    // This axis defines the exit face: put the point exactly on it.
                    coordinate = positive ? upper[axis] : lower[axis];
                } else if (!gotInsideOnce && tNear[axis] == t) {
                    // This axis defines the entry face: put the point exactly on it.
                    coordinate = positive ? lower[axis] : upper[axis];
                }
            }
            // The exact intersection lies on the cube; undo rounding that left it.
            result[axis] = Math.min(Math.max(coordinate, lower[axis]), upper[axis]);
        }
        return new Point3d(result[0], result[1], result[2]);
    }
    
    /**
     * Checks whether a vector has all three components equal to zero.
     *
     * @param v vector, must not be {@code null}
     * @return {@code true} if {@code v} is (0, 0, 0)
     */
    private static boolean isZeroVector(Vector3d v) {
        return v.x == 0 && v.y == 0 && v.z == 0;
    }
    
    /** 
     * Checks if p is inside the cube (boundaries included).
     * 
     * @param p point to be checked.
     *          Must not be {@code null}.
     * @param cube cube with the bounding box.
     *             Must not be {@code null}.
     * @return true if p is inside the cube, false otherwise
     */
    private boolean isPointInCube(Point3d p, OctNode cube) {
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }
        Point3d smallBoundary = cube.getSmallBoundary();
        Point3d largeBoundary = cube.getLargeBoundary();
        return p.x >= smallBoundary.x && p.x <= largeBoundary.x
                && p.y >= smallBoundary.y && p.y <= largeBoundary.y
                && p.z >= smallBoundary.z && p.z <= largeBoundary.z;
    }
    
    /**
     * Calculates the index of the child which can house the point provided.
     * Bit 2 is set when p is above the middle in x, bit 1 in y, and bit 0 in z.
     * 
     * @param p point, must not be {@code null}
     * @param cube cube, must not be {@code null}
     * @return index (0 - 7) of the child in cube, or -1 if p is not inside the cube
     */
    private int getOctantIndex(Point3d p, OctNode cube) {
        if (p == null) {
            throw new IllegalArgumentException("p is null");
        }
        if (cube == null) {
            throw new IllegalArgumentException("cube is null");
        }
        if (!isPointInCube(p, cube)) {
            return -1;
        }
        Point3d smallBoundary = cube.getSmallBoundary();
        Point3d largeBoundary = cube.getLargeBoundary();
        double middleX = (smallBoundary.x + largeBoundary.x) / 2;
        double middleY = (smallBoundary.y + largeBoundary.y) / 2;
        double middleZ = (smallBoundary.z + largeBoundary.z) / 2;
        int octantIndex = 0;
        if (p.x > middleX) {
            octantIndex += 4;
        }
        if (p.y > middleY) {
            octantIndex += 2;
        }
        if (p.z > middleZ) {
            octantIndex += 1;
        }
        return octantIndex;
    }
}
