From 24024c78eb60bb88fd87219bacafeb6a3394f8cd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev>
Date: Thu, 28 Jan 2021 10:19:18 +0100
Subject: [PATCH] ADD: subset of precomputed distances is used to speed up
 creation of a new node

---
 src/mhtree/BuildTree.java                    | 16 ++---
 src/mhtree/InternalNode.java                 |  5 +-
 src/mhtree/LeafNode.java                     | 12 ++--
 src/mhtree/MHTree.java                       |  4 +-
 src/mhtree/Node.java                         | 22 ++++---
 src/mhtree/benchmarking/BenchmarkConfig.java |  6 +-
 src/mhtree/benchmarking/MHTreeConfig.java    |  3 +-
 src/mhtree/benchmarking/RunBenchmark.java    | 62 ++++++++++++++++----
 8 files changed, 84 insertions(+), 46 deletions(-)

diff --git a/src/mhtree/BuildTree.java b/src/mhtree/BuildTree.java
index fc36b71..ddab38a 100644
--- a/src/mhtree/BuildTree.java
+++ b/src/mhtree/BuildTree.java
@@ -1,11 +1,14 @@
 package mhtree;
 
-import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
+import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances;
 import messif.buckets.BucketDispatcher;
 import messif.buckets.BucketStorageException;
 import messif.objects.LocalAbstractObject;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.List;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -17,7 +20,7 @@ class BuildTree {
     private final Node[] nodes;
     private final BitSet validNodeIndices;
 
-    private final AbstractRepresentation.PrecomputedDistances objectDistances;
+    private final PrecomputedDistances objectDistances;
     private final float[][] nodeDistances;
 
     private final BucketDispatcher bucketDispatcher;
@@ -33,7 +36,7 @@ class BuildTree {
         validNodeIndices = new BitSet(nodes.length);
         validNodeIndices.set(0, nodes.length);
 
-        objectDistances = new AbstractRepresentation.PrecomputedDistances(objects);
+        objectDistances = new PrecomputedDistances(objects);
         nodeDistances = new float[nodes.length][nodes.length];
 
         this.bucketDispatcher = bucketDispatcher;
@@ -110,7 +113,7 @@ class BuildTree {
             findClosestObjectIndices(furthestIndex, leafCapacity - 1, notProcessedObjectIndices)
                     .forEach(i -> objects.add(objectDistances.getObject(i)));
 
-            nodes[nodeIndex] = new LeafNode(objects, bucketDispatcher.createBucket(), insertType);
+            nodes[nodeIndex] = new LeafNode(objectDistances.getSubset(objects), bucketDispatcher.createBucket(), insertType);
         }
 
         return true;
@@ -127,7 +130,6 @@ class BuildTree {
                 .map(getMinHullObjectDistance)
                 .collect(Collectors.toList());
 
-
         Node node = nodes[Utils.getIndexOfMinElement(hullObjectDistances)];
         node.addObject(object);
     }
@@ -206,7 +208,7 @@ class BuildTree {
 
         List<Node> nodes = indices.stream().map(i -> this.nodes[i]).collect(Collectors.toList());
 
-        InternalNode parent = Node.createParent(nodes, insertType);
+        InternalNode parent = Node.createParent(nodes, objectDistances, insertType);
         nodes.forEach(node -> node.setParent(parent));
         parent.addChildren(nodes);
 
diff --git a/src/mhtree/InternalNode.java b/src/mhtree/InternalNode.java
index 2d6dd3d..65195d6 100644
--- a/src/mhtree/InternalNode.java
+++ b/src/mhtree/InternalNode.java
@@ -1,5 +1,6 @@
 package mhtree;
 
+import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances;
 import messif.objects.LocalAbstractObject;
 
 import java.io.Serializable;
@@ -17,8 +18,8 @@ public class InternalNode extends Node implements Serializable {
 
     private final List<Node> children;
 
-    InternalNode(List<LocalAbstractObject> objects, InsertType insertType) {
-        super(objects, insertType);
+    InternalNode(PrecomputedDistances distances, InsertType insertType) {
+        super(distances, insertType);
         children = new ArrayList<>();
     }
 
diff --git a/src/mhtree/LeafNode.java b/src/mhtree/LeafNode.java
index 6f94ab0..0378200 100644
--- a/src/mhtree/LeafNode.java
+++ b/src/mhtree/LeafNode.java
@@ -1,6 +1,6 @@
 package mhtree;
 
-import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
+import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances;
 import messif.buckets.BucketStorageException;
 import messif.buckets.LocalBucket;
 import messif.objects.LocalAbstractObject;
@@ -18,15 +18,11 @@ public class LeafNode extends Node implements Serializable {
     private static final long serialVersionUID = 1L;
     private LocalBucket bucket;
 
-    LeafNode(List<LocalAbstractObject> objects, LocalBucket bucket, InsertType insertType) throws BucketStorageException {
-        super(objects, insertType);
+    LeafNode(PrecomputedDistances distances, LocalBucket bucket, InsertType insertType) throws BucketStorageException {
+        super(distances, insertType);
 
         this.bucket = bucket;
-        this.bucket.addObjects(objects);
-    }
-
-    LeafNode(AbstractRepresentation.PrecomputedDistances distances, LocalBucket bucket, InsertType insertType) throws BucketStorageException {
-        this(distances.getObjects(), bucket, insertType);
+        this.bucket.addObjects(distances.getObjects());
     }
 
     public void addObject(LocalAbstractObject object) throws BucketStorageException {
diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java
index 99fd3e0..ba9beff 100644
--- a/src/mhtree/MHTree.java
+++ b/src/mhtree/MHTree.java
@@ -1,6 +1,6 @@
 package mhtree;
 
-import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
+import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances;
 import messif.algorithms.Algorithm;
 import messif.buckets.BucketDispatcher;
 import messif.buckets.BucketErrorCode;
@@ -46,7 +46,7 @@ public class MHTree extends Algorithm implements Serializable {
         this.insertType = insertType;
         bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams);
 
-        AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = numberOfThreads;
+        PrecomputedDistances.COMPUTATION_THREADS = numberOfThreads;
 
         root = new BuildTree(objects, leafCapacity, numberOfChildren, insertType, bucketDispatcher).getRoot();
     }
diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java
index 995d10e..c8e2b55 100644
--- a/src/mhtree/Node.java
+++ b/src/mhtree/Node.java
@@ -1,5 +1,6 @@
 package mhtree;
 
+import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation.PrecomputedDistances;
 import cz.muni.fi.disa.similarityoperators.cover.HullOptimizedRepresentationV3;
 import messif.buckets.BucketStorageException;
 import messif.objects.LocalAbstractObject;
@@ -15,25 +16,22 @@ public abstract class Node implements Serializable {
      * Serialization ID
      */
     private static final long serialVersionUID = 420L;
-
-    protected HullOptimizedRepresentationV3 hull;
     protected final InsertType insertType;
+    protected HullOptimizedRepresentationV3 hull;
     private Node parent;
 
-    Node(HullOptimizedRepresentationV3 hull, InsertType insertType) {
-        hull.build();
-        this.hull = hull;
+    Node(PrecomputedDistances distances, InsertType insertType) {
+        this.hull = new HullOptimizedRepresentationV3(distances);
+        this.hull.build();
         this.insertType = insertType;
     }
 
-    Node(List<LocalAbstractObject> objects, InsertType insertType) {
-        this(new HullOptimizedRepresentationV3(objects), insertType);
-    }
-
-    public static InternalNode createParent(List<Node> nodes, InsertType insertType) {
-        return new InternalNode(nodes.stream()
+    public static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType) {
+        List<LocalAbstractObject> objects = nodes.stream()
                 .flatMap(node -> node.getObjects().stream())
-                .collect(Collectors.toList()), insertType);
+                .collect(Collectors.toList());
+
+        return new InternalNode(distances.getSubset(objects), insertType);
     }
 
     public float getDistance(LocalAbstractObject object) {
diff --git a/src/mhtree/benchmarking/BenchmarkConfig.java b/src/mhtree/benchmarking/BenchmarkConfig.java
index 17bdfd9..647e6f7 100644
--- a/src/mhtree/benchmarking/BenchmarkConfig.java
+++ b/src/mhtree/benchmarking/BenchmarkConfig.java
@@ -39,8 +39,10 @@ public class BenchmarkConfig {
             for (MHTreeConfig mhTreeConfig : mhTreeConfigs) {
                 mhTreeConfig.setObjects(objects);
 
-                for (Consumer<MHTreeConfig> benchmarkFunction : benchmarkFunctions)
-                    benchmarkFunction.accept(mhTreeConfig);
+                for (int i = 0; i < benchmarkFunctions.size(); i++) {
+                    System.out.println("~Running" + mhTreeConfig + " on " + i + "-th function");
+                    benchmarkFunctions.get(i).accept(mhTreeConfig);
+                }
             }
         }
     }
diff --git a/src/mhtree/benchmarking/MHTreeConfig.java b/src/mhtree/benchmarking/MHTreeConfig.java
index 5a90644..7ff5f04 100644
--- a/src/mhtree/benchmarking/MHTreeConfig.java
+++ b/src/mhtree/benchmarking/MHTreeConfig.java
@@ -46,8 +46,7 @@ public class MHTreeConfig {
     @Override
     public String toString() {
         return "MHTreeConfig{" +
-                "objects=" + objects +
-                ", leafCapacity=" + leafCapacity +
+                "leafCapacity=" + leafCapacity +
                 ", numberOfChildren=" + numberOfChildren +
                 ", numberOfThreads=" + numberOfThreads +
                 ", insertType=" + insertType +
diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java
index 9b121fe..9a0994a 100644
--- a/src/mhtree/benchmarking/RunBenchmark.java
+++ b/src/mhtree/benchmarking/RunBenchmark.java
@@ -3,6 +3,7 @@ package mhtree.benchmarking;
 import messif.buckets.BucketStorageException;
 import messif.buckets.impl.MemoryStorageBucket;
 import messif.objects.LocalAbstractObject;
+import messif.operations.data.InsertOperation;
 import messif.operations.query.ApproxKNNQueryOperation;
 import mhtree.InsertType;
 import mhtree.MHTree;
@@ -31,11 +32,53 @@ public class RunBenchmark {
                 datasets,
                 Arrays.asList(
                         RunBenchmark::benchBuildTree,
-                        (MHTreeConfig config) -> benchApproxKNN(config, 10),
-                        (MHTreeConfig config) -> benchApproxKNN(config, 30)
+                        (MHTreeConfig config) -> benchApproxKNN(config, 30),
+                        RunBenchmark::benchInsert
                 )).execute();
     }
 
+    public static void benchInsert(MHTreeConfig config) {
+        try {
+            List<LocalAbstractObject> objects = config.getObjects();
+            List<LocalAbstractObject> initialObjects = objects.subList(0, objects.size() / 2);
+            List<LocalAbstractObject> restOfTheObjects = objects.subList(objects.size() / 2, objects.size());
+
+            long timeStart = System.currentTimeMillis();
+
+            MHTree tree = new MHTree(
+                    initialObjects,
+                    config.getLeafCapacity(),
+                    config.getNumberOfChildren(),
+                    config.getNumberOfThreads(),
+                    config.getInsertType(),
+                    MemoryStorageBucket.class,
+                    null);
+
+            long treeBuilt = System.currentTimeMillis();
+
+            long tookMax = Long.MIN_VALUE;
+            long tookMin = Long.MAX_VALUE;
+
+            for (LocalAbstractObject object : restOfTheObjects) {
+                long insertStart = System.currentTimeMillis();
+                tree.insert(new InsertOperation(object));
+                long took = System.currentTimeMillis() - insertStart;
+
+                tookMax = Math.max(tookMax, took);
+                tookMin = Math.min(tookMin, took);
+            }
+
+            long timeStop = System.currentTimeMillis();
+
+            System.out.println("~Building took: " + (treeBuilt - timeStart) + " ms");
+            System.out.println("~max: " + tookMax + " ms");
+            System.out.println("~min: " + tookMin + " ms");
+            System.out.println("~Took: " + (timeStop - timeStart) + " ms");
+        } catch (BucketStorageException ex) {
+            Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex);
+        }
+    }
+
     public static void benchApproxKNN(MHTreeConfig config, int k) {
         try {
             MHTree tree = new MHTree(config.getObjects(),
@@ -46,8 +89,6 @@ public class RunBenchmark {
                     MemoryStorageBucket.class,
                     null);
 
-            System.out.println("Tree built");
-
             long timeStart = System.currentTimeMillis();
 
             long tookMax = Long.MIN_VALUE;
@@ -56,8 +97,7 @@ public class RunBenchmark {
             for (LocalAbstractObject object : config.getObjects()) {
                 long nnStart = System.currentTimeMillis();
 
-                ApproxKNNQueryOperation op = new ApproxKNNQueryOperation(object, k);
-                tree.approxKNN(op);
+                tree.approxKNN(new ApproxKNNQueryOperation(object, k));
 
                 long took = System.currentTimeMillis() - nnStart;
 
@@ -67,11 +107,11 @@ public class RunBenchmark {
 
             long timeStop = System.currentTimeMillis();
 
-            System.out.println("nn took: " + (timeStop - timeStart) + " ms");
-            System.out.println("max: " + tookMax + " ms");
-            System.out.println("min: " + tookMin + " ms");
+            System.out.println("~nn took: " + (timeStop - timeStart) + " ms");
+            System.out.println("~max: " + tookMax + " ms");
+            System.out.println("~min: " + tookMin + " ms");
 
-            System.out.println(config + " k: " + k);
+            System.out.println("~" + config + " k: " + k);
 
         } catch (BucketStorageException ex) {
             Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex);
@@ -91,7 +131,7 @@ public class RunBenchmark {
                     null);
             long timeStop = System.currentTimeMillis();
 
-            System.out.println("Building took: " + (timeStop - timeStart) + " ms");
+            System.out.println("~Building took: " + (timeStop - timeStart) + " ms");
         } catch (BucketStorageException ex) {
             Logger.getLogger("BuildTree").log(Level.SEVERE, null, ex);
         }
-- 
GitLab