From 7f002a6a55cc2d82c1eedabdfea190a89f79a7ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev>
Date: Sun, 28 Feb 2021 19:39:40 +0100
Subject: [PATCH] FIX: simplified createRoot method, simplified computation of
 node distances, generalized findClosestItems

---
 src/mhtree/BuildTree.java | 222 +++++++++++++++++++-------------------
 1 file changed, 110 insertions(+), 112 deletions(-)

diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java
index b8b7c9f..812ba70 100644
--- a/src/mhtree/BuildTree.java
+++ b/src/mhtree/BuildTree.java
@@ -5,41 +5,36 @@ import messif.buckets.BucketDispatcher;
 import messif.buckets.BucketStorageException;
 import messif.objects.LocalAbstractObject;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.List;
 import java.util.function.BiFunction;
-import java.util.function.Function;
 import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 
 class BuildTree {
-
-    private final int leafCapacity;
     private final int arity;
-
-    private final Node[] nodes;
-    private final BitSet validNodeIndices;
+    private final int leafCapacity;
+    private final InsertType insertType;
+    private final ObjectToNodeDistance objectToNodeDistance;
+    private final NodeToNodeDistance nodeToNodeDistance;
+    private final BucketDispatcher bucketDispatcher;
 
     private final PrecomputedDistances objectDistances;
     private final float[][] nodeDistances;
 
-    private final InsertType insertType;
-    private final ObjectToNodeDistance objectToNodeDistance;
-    private final BucketDispatcher bucketDispatcher;
-    private final BiFunction<Float, Float, Float> distanceSelector;
+    private final Node[] nodes;
+    private final BitSet validNodeIndices;
 
-    private Node root;
+    private final Node root;
 
     BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, NodeToNodeDistance nodeToNodeDistance, BucketDispatcher bucketDispatcher) throws BucketStorageException {
-        this.arity = arity;
         this.leafCapacity = leafCapacity;
+        this.arity = arity;
         this.insertType = insertType;
         this.objectToNodeDistance = objectToNodeDistance;
+        this.nodeToNodeDistance = nodeToNodeDistance;
         this.bucketDispatcher = bucketDispatcher;
 
-        // Set distance selector function for two hull objects
-        boolean selectNearest = nodeToNodeDistance == NodeToNodeDistance.NEAREST_HULL_OBJECTS;
-        distanceSelector = selectNearest ? Math::min : Math::max;
-
         nodes = new Node[objects.size() / leafCapacity];
 
         validNodeIndices = new BitSet(nodes.length);
@@ -48,55 +43,38 @@ class BuildTree {
         objectDistances = new PrecomputedDistances(objects);
         nodeDistances = new float[nodes.length][nodes.length];
 
-        // Every object is stored in the root
-        if (objectDistances.getObjectCount() < leafCapacity) {
+        // Every object is stored directly in the root when the object count does not exceed the leaf capacity
+        if (objectDistances.getObjectCount() <= leafCapacity) {
             root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
             return;
         }
 
         createLeafNodes();
-
-        precomputeLeafNodeDistances();
-
-        buildTree();
+        root = createRoot();
     }
 
     public Node getRoot() {
         return root;
     }
 
-    private void buildTree() {
+    private Node createRoot() {
         while (validNodeIndices.cardinality() != 1) {
             BitSet notProcessedNodeIndices = (BitSet) validNodeIndices.clone();
 
             while (!notProcessedNodeIndices.isEmpty()) {
-                if (notProcessedNodeIndices.cardinality() < arity) {
-                    Set<Integer> nodeIndices = new HashSet<>(notProcessedNodeIndices.cardinality() - 1);
-
-                    int mainNodeIndex = notProcessedNodeIndices.nextSetBit(0);
-                    notProcessedNodeIndices.stream().skip(1).forEach(nodeIndices::add);
-
-                    mergeNodes(mainNodeIndex, nodeIndices);
-
+                if (notProcessedNodeIndices.cardinality() <= arity) {
+                    mergeNodes(notProcessedNodeIndices.stream().boxed().collect(Collectors.toList()));
                     break;
                 }
 
                 int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedNodeIndices);
                 notProcessedNodeIndices.clear(furthestNodeIndex);
 
-                Set<Integer> nnNodeIndices = new HashSet<>(arity - 1);
-
-                for (int i = 0; i < arity - 1; i++) {
-                    int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedNodeIndices, furthestNodeIndex);
-                    notProcessedNodeIndices.clear(index);
-                    nnNodeIndices.add(index);
-                }
-
-                mergeNodes(furthestNodeIndex, nnNodeIndices);
+                mergeNodes(furthestNodeIndex, findClosestItems(this::findClosestNodeIndex, furthestNodeIndex, arity - 1, notProcessedNodeIndices));
             }
         }
 
-        root = nodes[validNodeIndices.nextSetBit(0)];
+        return nodes[validNodeIndices.nextSetBit(0)];
     }
 
     private void createLeafNodes() throws BucketStorageException {
@@ -105,128 +83,148 @@ class BuildTree {
 
         for (int nodeIndex = 0; !notProcessedObjectIndices.isEmpty(); nodeIndex++) {
             if (notProcessedObjectIndices.cardinality() < leafCapacity) {
-                for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1))
-                    addObjectToClosestNode(i);
+                for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1)) {
+                    LocalAbstractObject object = objectDistances.getObject(i);
+
+                    nodes[getClosestNodeIndex(object)].addObject(object);
+                }
+
                 return;
             }
 
-            List<LocalAbstractObject> objects = new ArrayList<>(leafCapacity);
+            List<Integer> objectIndices = new ArrayList<>(leafCapacity);
 
             // Select a base object
             int furthestIndex = getFurthestIndex(objectDistances.getDistances(), notProcessedObjectIndices);
             notProcessedObjectIndices.clear(furthestIndex);
-            objects.add(objectDistances.getObject(furthestIndex));
+            objectIndices.add(furthestIndex);
 
             // Select the rest of the objects up to the total of leafCapacity
-            objects.addAll(findClosestObjects(furthestIndex, leafCapacity - 1, notProcessedObjectIndices));
+            objectIndices.addAll(findClosestItems(this::findClosestObjectIndex, furthestIndex, leafCapacity - 1, notProcessedObjectIndices));
+
+            List<LocalAbstractObject> objects = objectIndices.stream().map(objectDistances::getObject).collect(Collectors.toList());
 
             nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, objectToNodeDistance);
         }
     }
 
-    private void addObjectToClosestNode(int objectIndex) throws BucketStorageException {
-        LocalAbstractObject object = objectDistances.getObject(objectIndex);
+    private int getClosestNodeIndex(LocalAbstractObject object) {
+        double minDistance = Double.MAX_VALUE;
+        int closestNodeIndex = -1;
+
+        for (int candidateIndex = 0; candidateIndex < nodes.length; candidateIndex++) {
+            double distance = nodes[candidateIndex].getDistance(object);
 
-        Map<Node, Double> nodeToObjectDistance = Arrays.stream(nodes)
-                .collect(Collectors.toMap(Function.identity(), node -> node.getDistance(object)));
+            if (distance < minDistance) {
+                minDistance = distance;
+                closestNodeIndex = candidateIndex;
+            }
+        }
 
-        Collections.min(nodeToObjectDistance.entrySet(), Map.Entry.comparingByValue()).getKey().addObject(object);
+        return closestNodeIndex;
     }
 
-    private int getFurthestIndex(float[][] distMatrix, BitSet notUsedIndices) {
-        float max = Float.MIN_VALUE;
-        int maxIndex = notUsedIndices.nextSetBit(0);
+    private int getFurthestIndex(float[][] distanceMatrix, BitSet validIndices) {
+        float maxDistance = Float.MIN_VALUE;
+        int furthestIndex = validIndices.nextSetBit(0);
 
         while (true) {
-            float[] distances = distMatrix[maxIndex];
-            int candidateMaxIdx = this.objectDistances.maxDistInArray(distances, notUsedIndices);
-            if (!(distances[candidateMaxIdx] > max)) {
-                return maxIndex;
+            float[] distances = distanceMatrix[furthestIndex];
+            int candidateIndex = this.objectDistances.maxDistInArray(distances, validIndices);
+
+            if (!(distances[candidateIndex] > maxDistance)) {
+                return furthestIndex;
             }
 
-            max = distances[candidateMaxIdx];
-            maxIndex = candidateMaxIdx;
+            maxDistance = distances[candidateIndex];
+            furthestIndex = candidateIndex;
         }
     }
 
-    private Set<LocalAbstractObject> findClosestObjects(int baseObjectIndex, int numberOfObjects, BitSet notProcessedIndices) {
-        List<Integer> objectIndices = new ArrayList<>(1 + numberOfObjects);
+    private List<Integer> findClosestItems(BiFunction<List<Integer>, BitSet, Integer> findClosestItemIndex, int itemIndex, int numberOfItems, BitSet notProcessedItemIndices) {
+        List<Integer> itemIndices = new ArrayList<>(1 + numberOfItems);
+        itemIndices.add(itemIndex);
 
-        objectIndices.add(baseObjectIndex);
+        List<Integer> resultItemsIndices = new ArrayList<>();
 
-        return IntStream.range(0, numberOfObjects).mapToObj(i -> {
-            HashMap<Integer, Float> indexToDistance = new HashMap<>(objectIndices.size());
+        while (resultItemsIndices.size() != numberOfItems) {
+            int index = findClosestItemIndex.apply(itemIndices, notProcessedItemIndices);
 
-            for (int index : objectIndices) {
-                int nnIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), notProcessedIndices);
+            itemIndices.add(index);
+            notProcessedItemIndices.clear(index);
+            resultItemsIndices.add(index);
+        }
 
-                float distanceSum = objectIndices.stream()
-                        .map(objectIndex -> objectDistances.getDistance(objectIndex, nnIndex))
-                        .reduce(0f, Float::sum);
+        return resultItemsIndices;
+    }
 
-                indexToDistance.put(nnIndex, distanceSum);
-            }
+    private int findClosestNodeIndex(List<Integer> indices, BitSet validNodeIndices) {
+        double minDistance = Double.MAX_VALUE;
+        int closestNodeIndex = -1;
 
-            int closestPointIndex = Collections.min(indexToDistance.entrySet(), Map.Entry.comparingByValue()).getKey();
+        for (int candidateIndex = validNodeIndices.nextSetBit(0); candidateIndex >= 0; candidateIndex = validNodeIndices.nextSetBit(candidateIndex + 1)) {
+            double sum = 0;
+            for (int index : indices)
+                sum += nodeDistances[index][candidateIndex];
 
-            notProcessedIndices.clear(closestPointIndex);
-            objectIndices.add(closestPointIndex);
+            if (sum < minDistance) {
+                minDistance = sum;
+                closestNodeIndex = candidateIndex;
+            }
+        }
 
-            return objectDistances.getObject(closestPointIndex);
-        }).collect(Collectors.toSet());
+        return closestNodeIndex;
     }
 
-    private void precomputeLeafNodeDistances() {
-        for (int i = 0; i < nodes.length; i++) {
-            for (int j = i + 1; j < nodes.length; j++) {
-                float distance = Float.MAX_VALUE;
-
-                for (LocalAbstractObject firstHullObject : nodes[i].getHullObjects())
-                    for (LocalAbstractObject secondHullObject : nodes[j].getHullObjects())
-                        distance = distanceSelector.apply(distance, objectDistances.getDistance(firstHullObject, secondHullObject));
+    private int findClosestObjectIndex(List<Integer> indices, BitSet validObjectIndices) {
+        double minDistance = Double.MAX_VALUE;
+        int closestObjectIndex = -1;
 
-                if (distance == 0f)
-                    throw new RuntimeException("Zero distance between " + nodes[i].toString() + " and " + nodes[j].toString());
+        for (int index : indices) {
+            int candidateIndex = objectDistances.minDistInArray(objectDistances.getDistances(index), validObjectIndices);
+            double distance = indices.stream().mapToDouble(i -> objectDistances.getDistance(i, candidateIndex)).sum();
 
-                nodeDistances[i][j] = distance;
-                nodeDistances[j][i] = distance;
+            if (distance < minDistance) {
+                minDistance = distance;
+                closestObjectIndex = candidateIndex;
             }
         }
+
+        return closestObjectIndex;
+    }
+
+    private void mergeNodes(List<Integer> nodeIndices) {
+        int parentNodeIndex = nodeIndices.remove(0);
+        mergeNodes(parentNodeIndex, nodeIndices);
     }
 
-    private void mergeNodes(int mainNodeIndex, Set<Integer> nodeIndices) {
+    private void mergeNodes(int parentNodeIndex, List<Integer> nodeIndices) {
         if (nodeIndices.size() == 0) return;
 
-        Set<Integer> indices = new HashSet<>(nodeIndices);
-        indices.add(mainNodeIndex);
+        nodeIndices.add(parentNodeIndex);
 
-        Set<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toSet());
+        List<Node> nodes = nodeIndices.stream().map(i -> this.nodes[i]).collect(Collectors.toList());
 
         InternalNode parent = Node.createParent(nodes, objectDistances, insertType, objectToNodeDistance);
         nodes.forEach(node -> node.setParent(parent));
         parent.addChildren(nodes);
 
-        this.nodes[mainNodeIndex] = parent;
+        nodeIndices.forEach(index -> {
+            validNodeIndices.clear(index);
+            this.nodes[index] = null;
+        });
 
-        for (int i : nodeIndices) {
-            validNodeIndices.clear(i);
-            this.nodes[i] = null;
-        }
+        this.nodes[parentNodeIndex] = parent;
+        validNodeIndices.set(parentNodeIndex);
 
-        updateNodeDistances(mainNodeIndex, nodeIndices);
+        // Update node distances
+        validNodeIndices.stream().forEach(index -> computeNodeDistances(parentNodeIndex, index));
     }
 
-    private void updateNodeDistances(int baseNodeIndex, Set<Integer> nodeIndices) {
-        if (nodeIndices.size() == 0) return;
-
-        validNodeIndices.stream().forEach(i -> {
-            float distance = nodeDistances[baseNodeIndex][i];
-
-            for (int index : nodeIndices)
-                distance = distanceSelector.apply(distance, nodeDistances[index][i]);
+    private void computeNodeDistances(int i, int j) {
+        float distance = nodeToNodeDistance.getDistance(nodes[i], nodes[j], objectDistances::getDistance);
 
-            nodeDistances[baseNodeIndex][i] = distance;
-            nodeDistances[i][baseNodeIndex] = distance;
-        });
+        nodeDistances[i][j] = distance;
+        nodeDistances[j][i] = distance;
     }
 }
-- 
GitLab