From 7f002a6a55cc2d82c1eedabdfea190a89f79a7ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Sun, 28 Feb 2021 19:39:40 +0100 Subject: [PATCH] FIX: simplified createRoot method, simplified computation of node distances, generalized findClosestItem --- src/mhtree/BuildTree.java | 222 +++++++++++++++++++------------------- 1 file changed, 110 insertions(+), 112 deletions(-) diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java index b8b7c9f..812ba70 100644 --- a/src/mhtree/BuildTree.java +++ b/src/mhtree/BuildTree.java @@ -5,41 +5,36 @@ import messif.buckets.BucketDispatcher; import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; -import java.util.*; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; import java.util.function.BiFunction; -import java.util.function.Function; import java.util.stream.Collectors; -import java.util.stream.IntStream; class BuildTree { - - private final int leafCapacity; private final int arity; - - private final Node[] nodes; - private final BitSet validNodeIndices; + private final int leafCapacity; + private final InsertType insertType; + private final ObjectToNodeDistance objectToNodeDistance; + private final NodeToNodeDistance nodeToNodeDistance; + private final BucketDispatcher bucketDispatcher; private final PrecomputedDistances objectDistances; private final float[][] nodeDistances; - private final InsertType insertType; - private final ObjectToNodeDistance objectToNodeDistance; - private final BucketDispatcher bucketDispatcher; - private final BiFunction<Float, Float, Float> distanceSelector; + private final Node[] nodes; + private final BitSet validNodeIndices; - private Node root; + private final Node root; BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, NodeToNodeDistance nodeToNodeDistance, BucketDispatcher 
bucketDispatcher) throws BucketStorageException { - this.arity = arity; this.leafCapacity = leafCapacity; + this.arity = arity; this.insertType = insertType; this.objectToNodeDistance = objectToNodeDistance; + this.nodeToNodeDistance = nodeToNodeDistance; this.bucketDispatcher = bucketDispatcher; - // Set distance selector function for two hull objects - boolean selectNearest = nodeToNodeDistance == NodeToNodeDistance.NEAREST_HULL_OBJECTS; - distanceSelector = selectNearest ? Math::min : Math::max; - nodes = new Node[objects.size() / leafCapacity]; validNodeIndices = new BitSet(nodes.length); @@ -48,55 +43,38 @@ class BuildTree { objectDistances = new PrecomputedDistances(objects); nodeDistances = new float[nodes.length][nodes.length]; - // Every object is stored in the root - if (objectDistances.getObjectCount() < leafCapacity) { + // Every object is stored in the root in the case of small number of objects + if (objectDistances.getObjectCount() <= leafCapacity) { root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, objectToNodeDistance); return; } createLeafNodes(); - - precomputeLeafNodeDistances(); - - buildTree(); + root = createRoot(); } public Node getRoot() { return root; } - private void buildTree() { + private Node createRoot() { while (validNodeIndices.cardinality() != 1) { BitSet notProcessedNodeIndices = (BitSet) validNodeIndices.clone(); while (!notProcessedNodeIndices.isEmpty()) { - if (notProcessedNodeIndices.cardinality() < arity) { - Set<Integer> nodeIndices = new HashSet<>(notProcessedNodeIndices.cardinality() - 1); - - int mainNodeIndex = notProcessedNodeIndices.nextSetBit(0); - notProcessedNodeIndices.stream().skip(1).forEach(nodeIndices::add); - - mergeNodes(mainNodeIndex, nodeIndices); - + if (notProcessedNodeIndices.cardinality() <= arity) { + mergeNodes(notProcessedNodeIndices.stream().boxed().collect(Collectors.toList())); break; } int furthestNodeIndex = getFurthestIndex(nodeDistances, 
notProcessedNodeIndices); notProcessedNodeIndices.clear(furthestNodeIndex); - Set<Integer> nnNodeIndices = new HashSet<>(arity - 1); - - for (int i = 0; i < arity - 1; i++) { - int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedNodeIndices, furthestNodeIndex); - notProcessedNodeIndices.clear(index); - nnNodeIndices.add(index); - } - - mergeNodes(furthestNodeIndex, nnNodeIndices); + mergeNodes(furthestNodeIndex, findClosestItems(this::findClosestNodeIndex, furthestNodeIndex, arity - 1, notProcessedNodeIndices)); } } - root = nodes[validNodeIndices.nextSetBit(0)]; + return nodes[validNodeIndices.nextSetBit(0)]; } private void createLeafNodes() throws BucketStorageException { @@ -105,128 +83,148 @@ class BuildTree { for (int nodeIndex = 0; !notProcessedObjectIndices.isEmpty(); nodeIndex++) { if (notProcessedObjectIndices.cardinality() < leafCapacity) { - for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) - addObjectToClosestNode(i); + for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) { + LocalAbstractObject object = objectDistances.getObject(i); + + nodes[getClosestNodeIndex(object)].addObject(object); + } + return; } - List<LocalAbstractObject> objects = new ArrayList<>(leafCapacity); + List<Integer> objectIndices = new ArrayList<>(leafCapacity); // Select a base object int furthestIndex = getFurthestIndex(objectDistances.getDistances(), notProcessedObjectIndices); notProcessedObjectIndices.clear(furthestIndex); - objects.add(objectDistances.getObject(furthestIndex)); + objectIndices.add(furthestIndex); // Select the rest of the objects up to the total of leafCapacity - objects.addAll(findClosestObjects(furthestIndex, leafCapacity - 1, notProcessedObjectIndices)); + objectIndices.addAll(findClosestItems(this::findClosestObjectIndex, furthestIndex, leafCapacity - 1, notProcessedObjectIndices)); + + 
List<LocalAbstractObject> objects = objectIndices.stream().map(objectDistances::getObject).collect(Collectors.toList()); nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, objectToNodeDistance); } } - private void addObjectToClosestNode(int objectIndex) throws BucketStorageException { - LocalAbstractObject object = objectDistances.getObject(objectIndex); + private int getClosestNodeIndex(LocalAbstractObject object) { + double minDistance = Double.MAX_VALUE; + int closestNodeIndex = -1; + + for (int candidateIndex = 0; candidateIndex < nodes.length; candidateIndex++) { + double distance = nodes[candidateIndex].getDistance(object); - Map<Node, Double> nodeToObjectDistance = Arrays.stream(nodes) - .collect(Collectors.toMap(Function.identity(), node -> node.getDistance(object))); + if (distance < minDistance) { + minDistance = distance; + closestNodeIndex = candidateIndex; + } + } - Collections.min(nodeToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey().addObject(object); + return closestNodeIndex; } - private int getFurthestIndex(float[][] distMatrix, BitSet notUsedIndices) { - float max = Float.MIN_VALUE; - int maxIndex = notUsedIndices.nextSetBit(0); + private int getFurthestIndex(float[][] distanceMatrix, BitSet validIndices) { + float maxDistance = Float.MIN_VALUE; + int furthestIndex = validIndices.nextSetBit(0); while (true) { - float[] distances = distMatrix[maxIndex]; - int candidateMaxIdx = this.objectDistances.maxDistInArray(distances, notUsedIndices); - if (!(distances[candidateMaxIdx] > max)) { - return maxIndex; + float[] distances = distanceMatrix[furthestIndex]; + int candidateIndex = this.objectDistances.maxDistInArray(distances, validIndices); + + if (!(distances[candidateIndex] > maxDistance)) { + return furthestIndex; } - max = distances[candidateMaxIdx]; - maxIndex = candidateMaxIdx; + maxDistance = distances[candidateIndex]; + furthestIndex = candidateIndex; } } - 
private Set<LocalAbstractObject> findClosestObjects(int baseObjectIndex, int numberOfObjects, BitSet notProcessedIndices) { - List<Integer> objectIndices = new ArrayList<>(1 + numberOfObjects); + private List<Integer> findClosestItems(BiFunction<List<Integer>, BitSet, Integer> findClosestItemIndex, int itemIndex, int numberOfItems, BitSet notProcessedItemIndices) { + List<Integer> itemIndices = new ArrayList<>(1 + numberOfItems); + itemIndices.add(itemIndex); - objectIndices.add(baseObjectIndex); + List<Integer> resultItemsIndices = new ArrayList<>(); - return IntStream.range(0, numberOfObjects).mapToObj(i -> { - HashMap<Integer, Float> indexToDistance = new HashMap<>(objectIndices.size()); + while (resultItemsIndices.size() != numberOfItems) { + int index = findClosestItemIndex.apply(itemIndices, notProcessedItemIndices); - for (int index : objectIndices) { - int nnIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), notProcessedIndices); + itemIndices.add(index); + notProcessedItemIndices.clear(index); + resultItemsIndices.add(index); + } - float distanceSum = objectIndices.stream() - .map(objectIndex -> objectDistances.getDistance(objectIndex, nnIndex)) - .reduce(0f, Float::sum); + return resultItemsIndices; + } - indexToDistance.put(nnIndex, distanceSum); - } + private int findClosestNodeIndex(List<Integer> indices, BitSet validNodeIndices) { + double minDistance = Double.MAX_VALUE; + int closestNodeIndex = -1; - int closestPointIndex = Collections.min(indexToDistance.entrySet(), Map.Entry.comparingByValue()).getKey(); + for (int candidateIndex = validNodeIndices.nextSetBit(0); candidateIndex >= 0; candidateIndex = validNodeIndices.nextSetBit(candidateIndex + 1)) { + double sum = 0; + for (int index : indices) + sum += nodeDistances[index][candidateIndex]; - notProcessedIndices.clear(closestPointIndex); - objectIndices.add(closestPointIndex); + if (sum < minDistance) { + minDistance = sum; + closestNodeIndex = candidateIndex; + } + } - 
return objectDistances.getObject(closestPointIndex); - }).collect(Collectors.toSet()); + return closestNodeIndex; } - private void precomputeLeafNodeDistances() { - for (int i = 0; i < nodes.length; i++) { - for (int j = i + 1; j < nodes.length; j++) { - float distance = Float.MAX_VALUE; - - for (LocalAbstractObject firstHullObject : nodes[i].getHullObjects()) - for (LocalAbstractObject secondHullObject : nodes[j].getHullObjects()) - distance = distanceSelector.apply(distance, objectDistances.getDistance(firstHullObject, secondHullObject)); + private int findClosestObjectIndex(List<Integer> indices, BitSet validObjectIndices) { + double minDistance = Double.MAX_VALUE; + int closestObjectIndex = -1; - if (distance == 0f) - throw new RuntimeException("Zero distance between " + nodes[i].toString() + " and " + nodes[j].toString()); + for (int index : indices) { + int candidateIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), validObjectIndices); + double distance = indices.stream().mapToDouble(i -> objectDistances.getDistance(i, candidateIndex)).sum(); - nodeDistances[i][j] = distance; - nodeDistances[j][i] = distance; + if (distance < minDistance) { + minDistance = distance; + closestObjectIndex = candidateIndex; } } + + return closestObjectIndex; + } + + private void mergeNodes(List<Integer> nodeIndices) { + int parentNodeIndex = nodeIndices.remove(0); + mergeNodes(parentNodeIndex, nodeIndices); } - private void mergeNodes(int mainNodeIndex, Set<Integer> nodeIndices) { + private void mergeNodes(int parentNodeIndex, List<Integer> nodeIndices) { if (nodeIndices.size() == 0) return; - Set<Integer> indices = new HashSet<>(nodeIndices); - indices.add(mainNodeIndex); + nodeIndices.add(parentNodeIndex); - Set<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toSet()); + List<Node> nodes = nodeIndices.stream().map(i -> this.nodes[i]).collect(Collectors.toList()); InternalNode parent = Node.createParent(nodes, objectDistances, 
insertType, objectToNodeDistance); nodes.forEach(node -> node.setParent(parent)); parent.addChildren(nodes); - this.nodes[mainNodeIndex] = parent; + nodeIndices.forEach(index -> { + validNodeIndices.clear(index); + this.nodes[index] = null; + }); - for (int i : nodeIndices) { - validNodeIndices.clear(i); - this.nodes[i] = null; - } + this.nodes[parentNodeIndex] = parent; + validNodeIndices.set(parentNodeIndex); - updateNodeDistances(mainNodeIndex, nodeIndices); + // Update node distances + validNodeIndices.stream().forEach(index -> computeNodeDistances(parentNodeIndex, index)); } - private void updateNodeDistances(int baseNodeIndex, Set<Integer> nodeIndices) { - if (nodeIndices.size() == 0) return; - - validNodeIndices.stream().forEach(i -> { - float distance = nodeDistances[baseNodeIndex][i]; - - for (int index : nodeIndices) - distance = distanceSelector.apply(distance, nodeDistances[index][i]); + private void computeNodeDistances(int i, int j) { + float distance = nodeToNodeDistance.getDistance(nodes[i], nodes[j], objectDistances::getDistance); - nodeDistances[baseNodeIndex][i] = distance; - nodeDistances[i][baseNodeIndex] = distance; - }); + nodeDistances[i][j] = distance; + nodeDistances[j][i] = distance; } } -- GitLab