Commit fd35b451 authored by Radek Ošlejšek's avatar Radek Ošlejšek
Browse files

Merge branch 'kd-tree-knlogn' into 'master'

k-d tree construction complexity

See merge request grp-fidentis/analyst2!426
parents 57a17eb0 126dceff
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -124,7 +124,6 @@ public class MeshDistancePreciseGPU extends MeshDistanceVisitorImpl {
        TriangleKdNode farChild;

        while (currentNode != null) {
            //System.out.println(depth);
            if (depth % 3 == 0 && point.getPosition().x < currentNode.getSplit() ||
                depth % 3 == 1 && point.getPosition().y < currentNode.getSplit() ||
                depth % 3 == 2 && point.getPosition().z < currentNode.getSplit()) {
+13 −18
Original line number Diff line number Diff line
@@ -15,9 +15,7 @@ import javax.vecmath.Vector3d;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

@@ -25,7 +23,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class MeshDistanceGPUTest {
    private final float VERTEX_DELTA = 0.0001f;
    private final float SURFACE_DELTA = 0.05f;
    private static final int RANDOM_VERTEX_COUNT = 6000;
    private static final int RANDOM_VERTEX_COUNT = 1500;  // must be a multiple of 3 for tests working with triangles
    private static final int RANDOM_SEED = 1234;

    private static final String FIRST_FACE_PATH = "src/test/resources/cz/fidentis/analyst/basic-model-01.obj";
@@ -494,6 +492,9 @@ public class MeshDistanceGPUTest {
    @Disabled  // Disabled as it takes too long, use for development.
    @DisabledIf("OpenCLNotAvailable")
    public void PreciseRandomTest() {
        // WARNING: setting a high number of vertexes may cause recursion overflow.
        // This is because random triangles are created close to each other, often overlapping.
        // The tree then cannot separate them cleanly.
        MeshDistanceConfig configCPU = new MeshDistanceConfig(MeshDistanceConfig.Method.POINT_TO_TRIANGLE_NEAREST_NEIGHBORS, firstRandom, null, false, false);
        MeshDistanceNN CPUVisitor = (MeshDistanceNN) configCPU.getVisitor();
        secondRandom.accept(CPUVisitor);
@@ -502,11 +503,9 @@ public class MeshDistanceGPUTest {
        MeshDistancePreciseGPU GPUVisitor = (MeshDistancePreciseGPU) configGPU.getVisitor();
        secondRandom.accept(GPUVisitor);

        for (int i = 0; i < firstRandom.getNumberOfVertices(); i++) {
            double distanceCPU = CPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondRandom).get(i).getDistance();
            double distanceGPU = GPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondRandom).get(i).getDistance();
            assertTrue(distanceGPU <= distanceCPU + SURFACE_DELTA, String.format("CPU: %f GPU: %f", distanceCPU, distanceGPU));
        }
        distancesLess(CPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondRandom),
                GPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondRandom),
                SURFACE_DELTA);
    }

    @Test
@@ -521,11 +520,9 @@ public class MeshDistanceGPUTest {
        MeshDistancePreciseGPU GPUVisitor = (MeshDistancePreciseGPU) configGPU.getVisitor();
        secondFace.accept(GPUVisitor);

        for (int i = 0; i < firstFace.getNumberOfVertices(); i++) {
            double distanceCPU = CPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace).get(i).getDistance();
            double distanceGPU = GPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace).get(i).getDistance();
            assertTrue(distanceGPU <= distanceCPU + SURFACE_DELTA, String.format("CPU: %f GPU: %f", distanceCPU, distanceGPU));
        }
        distancesLess(CPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace),
                GPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace),
                SURFACE_DELTA);
    }

    @Test
@@ -540,10 +537,8 @@ public class MeshDistanceGPUTest {
        MeshDistancePreciseGPU GPUVisitor = (MeshDistancePreciseGPU) configGPU.getVisitor();
        secondFace.accept(GPUVisitor);

        for (int i = 0; i < firstFace.getNumberOfVertices(); i++) {
            double distanceCPU = CPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace).get(i).getDistance();
            double distanceGPU = GPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace).get(i).getDistance();
            assertTrue(distanceGPU <= distanceCPU + SURFACE_DELTA, String.format("CPU: %f GPU: %f", distanceCPU, distanceGPU));
        }
        distancesLess(CPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace),
                GPUVisitor.getDistancesOfVisitedFacets().getFacetMeasurement(secondFace),
                SURFACE_DELTA);
    }
}
+142 −67
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import cz.fidentis.analyst.data.kdtree.LeftBalancedKdNode;
import cz.fidentis.analyst.data.kdtree.LeftBalancedKdTree;
import cz.fidentis.analyst.data.mesh.MeshFacet;
import cz.fidentis.analyst.data.mesh.MeshPoint;
import org.apache.commons.lang3.tuple.ImmutablePair;

import javax.vecmath.Point3d;
import java.util.*;
@@ -22,7 +23,7 @@ public class LeftBalancedKdTreeImpl implements LeftBalancedKdTree {
    private List<LeftBalancedKdNode> underlyingData;

    /**
     * Constructor.
     * Constructor, loads facets and builds tree.
     *
     * @param facets MeshFacets from which the KdTree is built.
     */
@@ -32,10 +33,19 @@ public class LeftBalancedKdTreeImpl implements LeftBalancedKdTree {
            return;
        }
        underlyingData = new ArrayList<>();
        Set<Point3d> knownPoints = new HashSet<>();

        for (MeshFacet meshFacet : facets) {
            List<MeshPoint> meshPoints = meshFacet.getVertices();
            for (int j = 0; j < meshPoints.size(); j++) {
                underlyingData.add(new LeftBalancedKdNodeImpl(meshPoints.get(j), meshFacet, j));
                LeftBalancedKdNode node = new LeftBalancedKdNodeImpl(meshPoints.get(j), meshFacet, j);
                Point3d nodeLocation = node.getLocation();

                if (knownPoints.contains(nodeLocation)) {
                    continue;
                }
                underlyingData.add(node);
                knownPoints.add(nodeLocation);
            }
        }

@@ -112,53 +122,73 @@ public class LeftBalancedKdTreeImpl implements LeftBalancedKdTree {
    }

    private int build() {
        return buildRecursively(underlyingData.toArray(new LeftBalancedKdNode[0]), 0, 0);
        LeftBalancedKdNode[] sortedX = underlyingData.toArray(new LeftBalancedKdNode[0]);
        LeftBalancedKdNode[] sortedY = sortedX.clone();
        LeftBalancedKdNode[] sortedZ = sortedX.clone();

        Arrays.sort(sortedX, compXYZ());
        Arrays.sort(sortedY, compYZX());
        Arrays.sort(sortedZ, compZXY());

        return buildRecursively(sortedX, sortedY, sortedZ, 0, 0);
    }

    private int buildRecursively(LeftBalancedKdNode[] points, int kdTreeLocation, int depth) {
        if (points.length == 1) {
            kdTree[kdTreeLocation] = points[0];
            kdTree[kdTreeLocation].setKdTreeValues(kdTreeLocation, depth);
            return depth;
    private int buildRecursively(LeftBalancedKdNode[] sortedX, LeftBalancedKdNode[] sortedY, LeftBalancedKdNode[] sortedZ, int kdTreeLocation, int depth) {
        assert (sortedX.length == sortedY.length) && (sortedY.length == sortedZ.length);

        if (sortedX.length == 0) {
            return depth - 1;
        }

        if (points.length == 0) {
        if (sortedX.length == 1) {
            kdTree[kdTreeLocation] = sortedX[0];  // same points in all arrays, only ordered by different axis
            kdTree[kdTreeLocation].setKdTreeValues(kdTreeLocation, depth);
            return depth;
        }

        switch (depth % 3) {
            case 0: Arrays.sort(points, new CompX()); break;
            case 1: Arrays.sort(points, new CompY()); break;
            case 2: Arrays.sort(points, new CompZ()); break;
            default: break;
        int divideIndex = getDivider(sortedX.length);

        LeftBalancedKdNode divider = switch (depth % 3) {
            case 0 -> sortedX[divideIndex];
            case 1 -> sortedY[divideIndex];
            case 2 -> sortedZ[divideIndex];
            default -> null;
        };
        kdTree[kdTreeLocation] = divider;
        kdTree[kdTreeLocation].setKdTreeValues(kdTreeLocation, depth);

        ImmutablePair<LeftBalancedKdNode[], LeftBalancedKdNode[]> dividedX = partitionPoints(sortedX, divider, depth);
        ImmutablePair<LeftBalancedKdNode[], LeftBalancedKdNode[]> dividedY = partitionPoints(sortedY, divider, depth);
        ImmutablePair<LeftBalancedKdNode[], LeftBalancedKdNode[]> dividedZ = partitionPoints(sortedZ, divider, depth);

        int leftDepth = buildRecursively(dividedX.left, dividedY.left, dividedZ.left, 2 * kdTreeLocation + 1, depth + 1);
        int rightDepth = buildRecursively(dividedX.right, dividedY.right, dividedZ.right, 2 * kdTreeLocation + 2, depth + 1);

        return Math.max(leftDepth, rightDepth);
    }

        int maximum = largestPowerOfTwo(points.length);
        int remainder = points.length - (maximum - 1);
    /**
     * Calculates index where to split the array
     * as to make the resulting tree left-balanced.
     *
     * @param arrayLength Length of array stored in the tree.
     * @return Split index so that the resulting tree would be left-balanced.
     */
    private int getDivider(int arrayLength) {
        int maximum = largestPowerOfTwo(arrayLength);
        int remainder = arrayLength - (maximum - 1);

        // Bounds calculations
        // J. A. Baerentzen. (2003, Aug 25). On Left-balancing Binary Trees [Online].
        // Available: https://www2.imm.dtu.dk/pubdb/edoc/imm2535.pdf
        int leftAmount;
        int rightAmount;
        int divideIndex;
        if (remainder <= maximum / 2) {
            leftAmount = (maximum - 2)/2 + remainder;
            rightAmount = (maximum - 2)/2;
            divideIndex = (maximum - 2) / 2 + remainder;
        } else {
            leftAmount = (maximum - 2)/2 + maximum/2;
            rightAmount = (maximum - 2)/2 + remainder - maximum/2;
            divideIndex = (maximum - 2) / 2 + maximum / 2;
        }

        LeftBalancedKdNode[] left = Arrays.copyOfRange(points, 0, leftAmount);
        LeftBalancedKdNode[] right = Arrays.copyOfRange(points, leftAmount + 1, leftAmount + 1 + rightAmount);

        kdTree[kdTreeLocation] = points[leftAmount];
        kdTree[kdTreeLocation].setKdTreeValues(kdTreeLocation, depth);

        int leftDepth = buildRecursively(left, 2 * kdTreeLocation + 1, depth + 1);
        int rightDepth = buildRecursively(right, 2 * kdTreeLocation + 2, depth + 1);

        return Math.max(leftDepth, rightDepth);
        return divideIndex;
    }

    /**
@@ -176,49 +206,94 @@ public class LeftBalancedKdTreeImpl implements LeftBalancedKdTree {
        return 2 << (power - 1);
    }


    /**
     * Comparator for sorting nodes in the x-axis.
     * Partitions points into those lower than divider and higher than divider
     * based on their value at {@code depth} axis.
     *
     * @author Ľubomír Jurčišin
     * @param points Array of points to be partitioned.
     * @param divider Point around which to partition the points array.
     * @param depth Which axis to check for partition.
     *
     * @return Pair of two arrays: points lower than divider and points higher than divider.
     */
    static class CompX implements Comparator<LeftBalancedKdNode> {
        @Override
        public int compare(LeftBalancedKdNode a, LeftBalancedKdNode b) {
            if (a.getMeshPoint().getPosition().x == b.getMeshPoint().getPosition().x) {
                return 0;
    private ImmutablePair<LeftBalancedKdNode[], LeftBalancedKdNode[]> partitionPoints(LeftBalancedKdNode[] points, LeftBalancedKdNode divider, int depth) {
        ArrayList<LeftBalancedKdNode> lowerList = new ArrayList<>();
        ArrayList<LeftBalancedKdNode> upperList = new ArrayList<>();

        for (LeftBalancedKdNode p : points) {
            if (p == divider) {
                continue;
            }
            return a.getMeshPoint().getPosition().x < b.getMeshPoint().getPosition().x ? -1 : 1;
            switch (depth % 3) {
                case 0:
                    if (compXYZ().compare(p, divider) < 0) {
                        lowerList.add(p);
                    } else {
                        upperList.add(p);
                    }
                    break;
                case 1:
                    if (compYZX().compare(p, divider) < 0) {
                        lowerList.add(p);
                    } else {
                        upperList.add(p);
                    }
                    break;
                case 2:
                    if (compZXY().compare(p, divider) < 0) {
                        lowerList.add(p);
                    } else {
                        upperList.add(p);
                    }
                    break;
                default: assert false;
            }
        }
        LeftBalancedKdNode[] lower = lowerList.toArray(new LeftBalancedKdNode[0]);
        LeftBalancedKdNode[] upper = upperList.toArray(new LeftBalancedKdNode[0]);

    /**
     * Comparator for sorting nodes in the y-axis.
     *
     * @author Ľubomír Jurčišin
     */
    static class CompY implements Comparator<LeftBalancedKdNode> {
        @Override
        public int compare(LeftBalancedKdNode a, LeftBalancedKdNode b) {
            if (a.getMeshPoint().getPosition().y == b.getMeshPoint().getPosition().y) {
                return 0;
        return new ImmutablePair<>(lower, upper);
    }
            return a.getMeshPoint().getPosition().y < b.getMeshPoint().getPosition().y ? -1 : 1;

    private Comparator<LeftBalancedKdNode> compXYZ() {
        return (a, b) -> {
            if (Double.compare(a.getMeshPoint().getX(), b.getMeshPoint().getX()) != 0) {
                return Double.compare(a.getMeshPoint().getX(), b.getMeshPoint().getX());
            }

            if (Double.compare(a.getMeshPoint().getY(), b.getMeshPoint().getY()) != 0) {
                return Double.compare(a.getMeshPoint().getY(), b.getMeshPoint().getY());
            }

    /**
     * Comparator for sorting nodes in the z-axis.
     *
     * @author Ľubomír Jurčišin
     */
    static class CompZ implements Comparator<LeftBalancedKdNode> {
        @Override
        public int compare(LeftBalancedKdNode a, LeftBalancedKdNode b) {
            if (a.getMeshPoint().getPosition().z == b.getMeshPoint().getPosition().z) {
                return 0;
            return Double.compare(a.getMeshPoint().getZ(), b.getMeshPoint().getZ());
        };
    }

    private Comparator<LeftBalancedKdNode> compYZX() {
        return (a, b) -> {
            if (Double.compare(a.getMeshPoint().getY(), b.getMeshPoint().getY()) != 0) {
                return Double.compare(a.getMeshPoint().getY(), b.getMeshPoint().getY());
            }

            if (Double.compare(a.getMeshPoint().getZ(), b.getMeshPoint().getZ()) != 0) {
                return Double.compare(a.getMeshPoint().getZ(), b.getMeshPoint().getZ());
            }

            return Double.compare(a.getMeshPoint().getX(), b.getMeshPoint().getX());
        };
    }
            return a.getMeshPoint().getPosition().z < b.getMeshPoint().getPosition().z ? -1 : 1;

    private Comparator<LeftBalancedKdNode> compZXY() {
        return (a, b) -> {
            if (Double.compare(a.getMeshPoint().getZ(), b.getMeshPoint().getZ()) != 0) {
                return Double.compare(a.getMeshPoint().getZ(), b.getMeshPoint().getZ());
            }

            if (Double.compare(a.getMeshPoint().getX(), b.getMeshPoint().getX()) != 0) {
                return Double.compare(a.getMeshPoint().getX(), b.getMeshPoint().getX());
            }

            return Double.compare(a.getMeshPoint().getY(), b.getMeshPoint().getY());
        };
    }
}
+8 −2
Original line number Diff line number Diff line
@@ -11,13 +11,13 @@ import java.util.*;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class LeftBalancedKdTreeBuildTest {
    private static final Random random = new Random();
    private static final int VERTICES = 1500;
    private static final int SEED = 1234;
    private static final Random random = new Random(SEED);
    private static final List<MeshFacet> facets = new ArrayList<>();

    private record KdTreePoint(int index, int depth) {}


    @BeforeAll
    public static void setUp() {
        MeshFacet facet = MeshFactory.createEmptyMeshFacet();
@@ -34,6 +34,12 @@ public class LeftBalancedKdTreeBuildTest {
        facets.add(facet);
    }

    @Test
    public void noEmptySpace() {
        LeftBalancedKdTree kdTree = LeftBalancedKdTree.create(facets);
        kdTree.getLocationsArray();
    }


    @Test
    public void KdTreeBuildTest() {