From dd40600b22cb8c133cc41fed1f4370ed59166e09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev>
Date: Wed, 21 Apr 2021 11:51:43 +0200
Subject: [PATCH] ADD: M-Tree benchmark logic

---
 src/mhtree/Node.java                      | 14 ++---
 src/mhtree/benchmarking/RunBenchmark.java | 76 ++++++++++++++++++++++-
 2 files changed, 80 insertions(+), 10 deletions(-)

diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java
index ddaef35..600e80f 100644
--- a/src/mhtree/Node.java
+++ b/src/mhtree/Node.java
@@ -18,8 +18,8 @@ public abstract class Node implements Serializable {
      */
     private static final long serialVersionUID = 420L;
 
-    private final InsertType INSERT_TYPE;
-    private final ObjectToNodeDistance OBJECT_TO_NODE_DISTANCE;
+    private final InsertType insertType;
+    private final ObjectToNodeDistance objectToNodeDistance;
 
     private HullOptimizedRepresentationV3 hull;
 
@@ -27,8 +27,8 @@ public abstract class Node implements Serializable {
         this.hull = new HullOptimizedRepresentationV3(distances);
         this.hull.build();
 
-        this.INSERT_TYPE = insertType;
-        this.OBJECT_TO_NODE_DISTANCE = objectToNodeDistance;
+        this.insertType = insertType;
+        this.objectToNodeDistance = objectToNodeDistance;
     }
 
     protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, MergeType mergeType) {
@@ -70,11 +70,11 @@ public abstract class Node implements Serializable {
     }
 
     protected double getDistance(LocalAbstractObject object) {
-        return OBJECT_TO_NODE_DISTANCE.getDistance(object, this);
+        return objectToNodeDistance.getDistance(object, this);
     }
 
     protected double getDistance(LocalAbstractObject object, PrecomputedDistances distances) {
-        return OBJECT_TO_NODE_DISTANCE.getDistance(object, this, distances);
+        return objectToNodeDistance.getDistance(object, this, distances);
     }
 
     protected double getDistanceToNearest(LocalAbstractObject object) {
@@ -100,7 +100,7 @@ public abstract class Node implements Serializable {
     protected void addObjectIntoHull(LocalAbstractObject object, PrecomputedDistances distances) {
         if (isCovered(object, distances)) return;
 
-        if (INSERT_TYPE == InsertType.INCREMENTAL) {
+        if (insertType == InsertType.INCREMENTAL) {
             hull.addHullObject(object);
             return;
         }
diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java
index 1dc0fc4..aa723ff 100644
--- a/src/mhtree/benchmarking/RunBenchmark.java
+++ b/src/mhtree/benchmarking/RunBenchmark.java
@@ -1,6 +1,7 @@
 package mhtree.benchmarking;
 
 import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
+import messif.algorithms.AlgorithmMethodException;
 import messif.buckets.BucketStorageException;
 import messif.objects.LocalAbstractObject;
 import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2;
@@ -8,6 +9,7 @@ import messif.objects.util.AbstractObjectList;
 import messif.objects.util.RankedAbstractObject;
 import messif.objects.util.StreamGenericAbstractObjectIterator;
 import messif.operations.Approximate;
+import messif.operations.data.BulkInsertOperation;
 import messif.operations.query.ApproxKNNQueryOperation;
 import messif.operations.query.KNNQueryOperation;
 import messif.statistics.Statistics;
@@ -15,16 +17,18 @@ import mhtree.InsertType;
 import mhtree.MHTree;
 import mhtree.MergeType;
 import mhtree.ObjectToNodeDistance;
+import mtree.MTree;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
 public class RunBenchmark {
-    public static void main(String[] args) throws IOException, BucketStorageException {
+    public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException {
         if (args.length != 5) {
             throw new IllegalArgumentException("Unexpected number of params");
         }
@@ -50,7 +54,7 @@ public class RunBenchmark {
                 break;
         }
 
-        percentageToRecall(new MHTreeConfig(
+        percentageToRecallMHTree(new MHTreeConfig(
                         leafCapacity,
                         nodeDegree,
                         insertType,
@@ -61,7 +65,7 @@ public class RunBenchmark {
         );
     }
 
-    private static void percentageToRecall(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException {
+    private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException {
         MHTree mTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree)
                 .objectToNodeDistance(config.objectToNodeDistance)
                 .mergeType(MergeType.REPRESENTATION_BASED)
@@ -145,6 +149,72 @@ public class RunBenchmark {
         }
     }
 
+    private static void percentageToRecallMTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, RuntimeException, InstantiationException {
+        int numberOfObjects = objects.size();
+        // Benchmark: builds an MTree over `objects`, then for every k in `ks` prints one CSV row of recall statistics per tested percentage.
+        MTree mTree = new MTree(config.nodeDegree, config.leafCapacity); // NOTE(review): passed as (nodeDegree, leafCapacity) - confirm against the MTree constructor's parameter order
+        // Collections.shuffle reorders the caller's list in place; the shuffled list is what the bulk insert below receives.
+        Collections.shuffle(objects);
+
+        BulkInsertOperation op = new BulkInsertOperation(objects);
+
+        mTree.insert(op);
+
+        mTree.printStatistics();
+        // CSV header for the rows printed inside the loops below.
+        System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)");
+        // Loop state declared once for all k; minimalRecall and percentage are reset to 0 at the end of every k iteration.
+        double minimalRecall = 0;
+        int percentage = 0;
+        int percentageStep = 5;
+
+        for (int k : ks) {
+            List<Double> recalls = new ArrayList<>(numberOfObjects);
+            for (int i = 0; i < numberOfObjects; i++) {
+                recalls.add(0.0);
+            }
+            // NOTE(review): the loop ends only when the minimum recall is exactly 1.0 and `percentage` has no upper bound - confirm measureRecall can always reach exactly 1.0.
+            while (minimalRecall != 1.0) {
+                for (int i = 0; i < numberOfObjects; i++) {
+                    if (recalls.get(i) != 1.0) { // objects whose stored recall is already 1.0 are not queried again for this k
+                        ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation(
+                                objects.get(i),
+                                k,
+                                percentage,
+                                Approximate.LocalSearchType.PERCENTAGE,
+                                LocalAbstractObject.UNKNOWN_DISTANCE
+                        );
+
+                        mTree.executeOperation(operation);
+
+                        recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree));
+                    }
+                }
+                // Stats is given a copy of the list - presumably so it cannot reorder `recalls`, whose indices must stay aligned with `objects`; confirm against Stats.
+                Stats recallStats = new Stats(new ArrayList<>(recalls));
+                // One CSV row per (k, percentage); the objectToNodeDistance column is filled from config.objectToNodeDistance.
+                System.out.println(String.join(",",
+                        String.valueOf(config.leafCapacity),
+                        String.valueOf(config.nodeDegree),
+                        String.valueOf(config.objectToNodeDistance),
+                        String.valueOf(k),
+                        String.valueOf(percentage),
+                        String.format("%.2f,%.2f,%.2f,%.2f",
+                                recallStats.getMin(),
+                                recallStats.getAverage(),
+                                recallStats.getMedian(),
+                                recallStats.getMax())));
+
+                minimalRecall = recallStats.getMin();
+                percentage += percentageStep;
+            }
+            // Start the next k from scratch: minimum recall 0 and percentage 0.
+            minimalRecall = 0;
+            percentage = 0;
+        }
+    }
+
+
     private static List<LocalAbstractObject> loadDataset(String path) throws IOException {
         return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path));
     }
-- 
GitLab