From d0e726c7334fd7bc4c696c8417349a2513d3bd16 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev>
Date: Fri, 19 Feb 2021 09:49:35 +0100
Subject: [PATCH] ADD: more explicit tree building steps

---
 src/mhtree/BuildTree.java | 90 +++++++++++++++++++--------------------
 1 file changed, 44 insertions(+), 46 deletions(-)

diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java
index 42e1f09..f6f7c97 100644
--- a/src/mhtree/BuildTree.java
+++ b/src/mhtree/BuildTree.java
@@ -30,6 +30,9 @@ class BuildTree {
     BuildTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, DistanceMeasure distanceMeasure, BucketDispatcher bucketDispatcher) throws BucketStorageException {
         this.arity = arity;
         this.leafCapacity = leafCapacity;
+        this.insertType = insertType;
+        this.distanceMeasure = distanceMeasure;
+        this.bucketDispatcher = bucketDispatcher;
 
         nodes = new Node[objects.size() / leafCapacity];
 
@@ -39,9 +42,15 @@ class BuildTree {
         objectDistances = new PrecomputedDistances(objects);
         nodeDistances = new float[nodes.length][nodes.length];
 
-        this.insertType = insertType;
-        this.distanceMeasure = distanceMeasure;
-        this.bucketDispatcher = bucketDispatcher;
+        // Every object is stored in the root
+        if (objectDistances.getObjectCount() < leafCapacity) {
+            root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, distanceMeasure);
+            return;
+        }
+
+        createLeafNodes();
+
+        precomputeNodeDistances();
 
         buildTree();
     }
@@ -50,50 +59,41 @@ class BuildTree {
         return root;
     }
 
-    private void buildTree() throws BucketStorageException {
-        if (!initHullPoints()) return;
-
-        precomputeHullDistances();
-
+    private void buildTree() {
         while (validNodeIndices.cardinality() != 1) {
-            BitSet notProcessedIndices = (BitSet) validNodeIndices.clone();
+            BitSet notProcessedNodeIndices = (BitSet) validNodeIndices.clone();
 
-            while (!notProcessedIndices.isEmpty()) {
-                if (notProcessedIndices.cardinality() < arity) {
-                    List<Integer> restOfTheIndices = new ArrayList<>();
+            while (!notProcessedNodeIndices.isEmpty()) {
+                if (notProcessedNodeIndices.cardinality() < arity) {
+                    List<Integer> nodeIndices = new ArrayList<>();
 
-                    notProcessedIndices.stream().forEach(restOfTheIndices::add);
+                    notProcessedNodeIndices.stream().forEach(nodeIndices::add);
 
-                    int mainIndex = restOfTheIndices.remove(0);
-                    mergeHulls(mainIndex, restOfTheIndices);
+                    int mainNodeIndex = nodeIndices.remove(0);
+                    mergeNodes(mainNodeIndex, nodeIndices);
 
                     break;
                 }
 
-                int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedIndices);
-                notProcessedIndices.clear(furthestNodeIndex);
+                int furthestNodeIndex = getFurthestIndex(nodeDistances, notProcessedNodeIndices);
+                notProcessedNodeIndices.clear(furthestNodeIndex);
 
-                List<Integer> nnIndices = new ArrayList<>();
+                List<Integer> nnNodeIndices = new ArrayList<>();
 
                 for (int i = 0; i < arity - 1; i++) {
-                    int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedIndices, furthestNodeIndex);
-                    notProcessedIndices.clear(index);
-                    nnIndices.add(index);
+                    int index = objectDistances.minDistInArrayExceptIdx(nodeDistances[furthestNodeIndex], notProcessedNodeIndices, furthestNodeIndex);
+                    notProcessedNodeIndices.clear(index);
+                    nnNodeIndices.add(index);
                 }
 
-                mergeHulls(furthestNodeIndex, nnIndices);
+                mergeNodes(furthestNodeIndex, nnNodeIndices);
             }
         }
 
         root = nodes[validNodeIndices.nextSetBit(0)];
     }
 
-    private boolean initHullPoints() throws BucketStorageException {
-        if (objectDistances.getObjectCount() < leafCapacity) {
-            root = new LeafNode(objectDistances, bucketDispatcher.createBucket(), insertType, distanceMeasure);
-            return false;
-        }
-
+    private void createLeafNodes() throws BucketStorageException {
         BitSet notProcessedObjectIndices = new BitSet(objectDistances.getObjectCount());
         notProcessedObjectIndices.set(0, objectDistances.getObjectCount());
 
@@ -101,7 +101,7 @@ class BuildTree {
             if (notProcessedObjectIndices.cardinality() < leafCapacity) {
                 for (int i = notProcessedObjectIndices.nextSetBit(0); i >= 0; i = notProcessedObjectIndices.nextSetBit(i + 1))
                     addObjectToClosestNode(i);
-                return true;
+                return;
             }
 
             List<LocalAbstractObject> objects = new ArrayList<>();
@@ -116,15 +116,13 @@ class BuildTree {
 
             nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType, distanceMeasure);
         }
-
-        return true;
     }
 
     private void addObjectToClosestNode(int objectIndex) throws BucketStorageException {
         LocalAbstractObject object = objectDistances.getObject(objectIndex);
 
         Function<Node, Float> getMinHullObjectDistance = node -> node.getHullObjects().stream()
-                .map(obj -> obj.getDistance(object))
+                .map(object::getDistance)
                 .reduce(Float.MAX_VALUE, Math::min);
 
         Map<Node, Float> nodeToObjectDistance = Arrays.stream(nodes)
@@ -177,7 +175,7 @@ class BuildTree {
         }).collect(Collectors.toList());
     }
 
-    private void precomputeHullDistances() {
+    private void precomputeNodeDistances() {
         for (int i = 0; i < nodes.length; i++) {
             for (int j = i + 1; j < nodes.length; j++) {
                 float minDistance = Float.MAX_VALUE;
@@ -195,11 +193,11 @@ class BuildTree {
         }
     }
 
-    private void mergeHulls(int mainHullIndex, List<Integer> otherHullIndices) {
-        if (otherHullIndices.size() == 0) return;
+    private void mergeNodes(int mainNodeIndex, List<Integer> nodeIndices) {
+        if (nodeIndices.size() == 0) return;
 
-        List<Integer> indices = new ArrayList<>(otherHullIndices);
-        indices.add(mainHullIndex);
+        List<Integer> indices = new ArrayList<>(nodeIndices);
+        indices.add(mainNodeIndex);
 
         List<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toList());
 
@@ -207,27 +205,27 @@ class BuildTree {
         nodes.forEach(node -> node.setParent(parent));
         parent.addChildren(nodes);
 
-        this.nodes[mainHullIndex] = parent;
+        this.nodes[mainNodeIndex] = parent;
 
-        for (int i : otherHullIndices) {
+        for (int i : nodeIndices) {
             validNodeIndices.clear(i);
             this.nodes[i] = null;
         }
 
-        updateHullDistances(mainHullIndex, otherHullIndices);
+        updateNodeDistances(mainNodeIndex, nodeIndices);
     }
 
-    private void updateHullDistances(int baseHullIndex, List<Integer> otherHullIndices) {
-        if (otherHullIndices.size() == 0) return;
+    private void updateNodeDistances(int baseNodeIndex, List<Integer> nodeIndices) {
+        if (nodeIndices.size() == 0) return;
 
         validNodeIndices.stream().forEach(i -> {
-            float minDistance = nodeDistances[baseHullIndex][i];
+            float minDistance = nodeDistances[baseNodeIndex][i];
 
-            for (int index : otherHullIndices)
+            for (int index : nodeIndices)
                 minDistance = Math.min(minDistance, nodeDistances[index][i]);
 
-            nodeDistances[baseHullIndex][i] = minDistance;
-            nodeDistances[i][baseHullIndex] = minDistance;
+            nodeDistances[baseNodeIndex][i] = minDistance;
+            nodeDistances[i][baseNodeIndex] = minDistance;
         });
     }
 }
-- 
GitLab