package cz.fidentis.analyst.visitors.octree;

import cz.fidentis.analyst.mesh.core.CornerTableRow;
import cz.fidentis.analyst.mesh.core.MeshFacet;
import cz.fidentis.analyst.mesh.core.MeshFacetImpl;
import cz.fidentis.analyst.mesh.core.MeshPoint;
import cz.fidentis.analyst.mesh.core.MeshPointImpl;
import cz.fidentis.analyst.octree.Octree;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;
import org.apache.commons.lang3.tuple.Pair;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.Test;

/**
 * @author Enkh-Undral EnkhBayar
 */
public class OctreeArrayIntersectionVisitorTest {
    
    private MeshFacet getTrivialFacet(double offset, double size) {
        MeshFacet facet = new MeshFacetImpl();
        facet.addVertex(new MeshPointImpl(new Point3d(0, 0, offset), new Vector3d(0, 0, 1), new Vector3d()));
        facet.addVertex(new MeshPointImpl(new Point3d(size, 0, offset), new Vector3d(0, 0, 1), new Vector3d()));
        facet.addVertex(new MeshPointImpl(new Point3d(0, size, offset), new Vector3d(0, 0, 1), new Vector3d()));

        facet.getCornerTable().addRow(new CornerTableRow(0, -1));
        facet.getCornerTable().addRow(new CornerTableRow(1, -1));
        facet.getCornerTable().addRow(new CornerTableRow(2, -1));

        return facet;
    }
    
    private void printIntersection(Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> map) {
        StringBuilder builder = new StringBuilder();
        for (var entry : map.entrySet()) {
            MeshFacet mainFacet = entry.getKey();
            builder.append(mainFacet.toString()).append('\n');
            var startingPointMap = entry.getValue();
            int i = 0;
            for (var startingPointEntry : startingPointMap.entrySet()) {
                if (i == startingPointMap.size() - 1) {
                    builder.append("\\- ");
                } else {
                    builder.append("|- ");
                }
                int index = startingPointEntry.getKey();
                builder.append(index).append(": ").append(mainFacet.getVertex(index).getPosition());
                builder.append('\n');
                int j = 0;
                var intersectionMap = startingPointEntry.getValue();
                for (var intersectionEntry : intersectionMap.entrySet()) {
                    builder.append('\t');
                    if (j == intersectionMap.size() - 1) {
                        builder.append("\\- ");
                    } else {
                        builder.append("|- ");
                    }
                    builder.append(intersectionEntry.getKey());
                    builder.append(' ');
                    builder.append(intersectionEntry.getValue());
                }
                builder.append('\n');
                i++;
            }
        }
        System.out.println(builder);
    }
    
    private void checkMap(Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> map, Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> correct) {
        assertTrue(map != null, "map is null");
        Set<MeshFacet> mapKeySet = new HashSet<>(map.keySet());
        assertTrue(map.size() == correct.size());
        for (var entry : correct.entrySet()) {
            MeshFacet mainFacet = entry.getKey();
            assertTrue(map.containsKey(mainFacet));
            Map<Integer, Map<MeshFacet, Point3d>> startingPointMapCorrect = entry.getValue();
            Map<Integer, Map<MeshFacet, Point3d>> startingPointMap = map.get(mainFacet);
            assertTrue(startingPointMap.size() == startingPointMapCorrect.size());
            Set<Integer> startingPointMapKeySet = new HashSet(startingPointMap.keySet());
            for (var startingPointEntry : startingPointMapCorrect.entrySet()) {
                int index = startingPointEntry.getKey();
                assertTrue(startingPointMap.containsKey(index));
                var intersectionsMapCorrect = startingPointEntry.getValue();
                var intersectionsMap = startingPointMap.get(index);
                assertTrue(intersectionsMapCorrect.size() == intersectionsMap.size());
                Set<MeshFacet> intersectionMapKeySet = new HashSet(intersectionsMap.keySet());
                for (var intersectionEntry : intersectionsMapCorrect.entrySet()) {
                    MeshFacet facet = intersectionEntry.getKey();
                    assertTrue(intersectionsMap.containsKey(facet));
                    Point3d pointCorrect = intersectionEntry.getValue();
                    Point3d point = intersectionsMap.get(facet);
                    assertTrue(point.equals(pointCorrect));
                    intersectionMapKeySet.remove(facet);
                }
                assertTrue(intersectionMapKeySet.isEmpty());
                startingPointMapKeySet.remove(index);
            }
            assertTrue(startingPointMapKeySet.isEmpty());
            mapKeySet.remove(mainFacet);
        }
        assertTrue(mapKeySet.isEmpty());
    }
    
    private List<Integer> findIndices(List<Point3d> points, List<MeshPoint> vertices) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < points.size(); i++) {
            indices.add(-1);
        }
        int i = 0;
        for (MeshPoint vertex : vertices) {
            for (int j = 0; j < points.size(); j++) {
                if (vertex.getPosition().equals(points.get(j))) {
                    indices.set(j, i);
                }
            }
            i++;
        }
        if (indices.contains(-1)) {
            return null;
        }
        return indices;
    }
    
    @Test
    public void zeroCollision() {
        MeshFacet main = getTrivialFacet(4, 0);
        MeshFacet facet = new MeshFacetImpl();
        facet.addVertex(new MeshPointImpl(new Point3d(1, 1, 1), new Vector3d(0, 0, 1), new Vector3d()));
        facet.addVertex(new MeshPointImpl(new Point3d(2, 1, 1), new Vector3d(0, 0, 1), new Vector3d()));
        facet.addVertex(new MeshPointImpl(new Point3d(1, 2, 1), new Vector3d(0, 0, 1), new Vector3d()));

        facet.getCornerTable().addRow(new CornerTableRow(0, -1));
        facet.getCornerTable().addRow(new CornerTableRow(1, -1));
        facet.getCornerTable().addRow(new CornerTableRow(2, -1));
        var visitor = new OctreeArrayIntersectionVisitor(main);
        visitor.visitMeshFacet(facet);
        var intersections = visitor.getIntersections();
        assertTrue(intersections.isEmpty());
    }

    @Test
    public void twoMeshesOneCollision() {
        MeshFacet main = getTrivialFacet(1, 2);
        MeshFacet second = getTrivialFacet(2, 1);
        
        List<Point3d> startingPoints = List.of(new Point3d(0, 0, 1));
        List<Integer> indices = findIndices(startingPoints, main.getVertices());
        assertTrue(indices != null);

        var visitor = new OctreeArrayIntersectionVisitor(main);
        visitor.visitMeshFacet(second);
        var intersections = visitor.getIntersections();
        Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> correct = Map.of(
            main,
            Map.of(
                indices.get(0), 
                    Map.of(second, new Point3d(0, 0, 2))
            )
        );
        checkMap(intersections, correct);
    }
    
    @Test
    public void twoMeshesMultipleCollision() {
        MeshFacet main = getTrivialFacet(2, 1);
        MeshFacet second = getTrivialFacet(1, 2);
        var visitor = new OctreeArrayIntersectionVisitor(main);
        visitor.visitMeshFacet(second);
        var intersections = visitor.getIntersections();
        List<Point3d> startingPoints = List.of(new Point3d(0, 0, 2), new Point3d(0, 1, 2), new Point3d(1, 0, 2));
        List<Integer> indices = findIndices(startingPoints, main.getVertices());
        assertTrue(indices != null);
        Map<MeshFacet, Map<Integer, Map<MeshFacet, Point3d>>> correct = Map.of(
            main,
            Map.of(
                indices.get(0), Map.of(second, new Point3d(0, 0, 1)), 
                indices.get(1), Map.of(second, new Point3d(0, 1, 1)),
                indices.get(2), Map.of(second, new Point3d(1, 0, 1))
            )
        );
        checkMap(intersections, correct);
    }
}
