From c0aa6a1b68c714aa4249eb4421f93c4dab7ca2df Mon Sep 17 00:00:00 2001
From: Vlastislav Dohnal <dohnal@fi.muni.cz>
Date: Fri, 25 Jun 2021 16:55:23 +0200
Subject: [PATCH] Minor refactoring of implementation and javadoc update *
 RunBenchmark improved by calling Fat Factor and making M-tree code parallel.

---
 src/mhtree/InsertType.java                   |   7 +-
 src/mhtree/Node.java                         |   7 +-
 src/mhtree/ObjectToNodeDistance.java         |   1 -
 src/mhtree/ObjectToNodeDistanceRank.java     |   5 +
 src/mhtree/benchmarking/KnnResultPair.java   |  15 +
 src/mhtree/benchmarking/RunBenchmark.java    | 278 ++++++++++++-------
 src/mhtree/benchmarking/SearchInfoState.java |  17 ++
 src/mhtree/benchmarking/SearchState.java     |   4 +-
 8 files changed, 225 insertions(+), 109 deletions(-)
 create mode 100644 src/mhtree/benchmarking/KnnResultPair.java
 create mode 100644 src/mhtree/benchmarking/SearchInfoState.java

diff --git a/src/mhtree/InsertType.java b/src/mhtree/InsertType.java
index a57b43b..50f392f 100644
--- a/src/mhtree/InsertType.java
+++ b/src/mhtree/InsertType.java
@@ -5,18 +5,21 @@ package mhtree;
  */
 public enum InsertType {
     /**
-     * When the inserted object is not covered by a node, all objects under such node are retrieved,
+     * When the inserted object is not covered by a node, all objects under such node are retrieved (recursively down to the buckets),
      * and a new hull is built, replacing the current one.
      */
     GREEDY,
 
     /**
      * When the inserted object is not covered by node, we iterate over hull objects beginning with the nearest one.
-     * We try to replace the hull object by removing it from the hull and replacing it with the inserted object.
+     * We try to replace an existing hull object by replacing it with the inserted object.
      * If the removed hull object is covered by the new hull, we are done.
      * If no such hull object is found, the inserted object is simply added as a new hull object.
      */
     INCREMENTAL,
 
+    /** Take current hull objects and the newly inserted object and compute hull out of them.
+     * Vlasta: Is this identical to INCREMENTAL?
+     */
     ADD_HULL_OBJECT,
 }
diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java
index 4fb191f..3990bd9 100644
--- a/src/mhtree/Node.java
+++ b/src/mhtree/Node.java
@@ -7,6 +7,7 @@ import messif.objects.LocalAbstractObject;
 
 import java.io.Serializable;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
 
@@ -36,7 +37,8 @@ public abstract class Node implements Serializable {
         this.objectToNodeDistance = objectToNodeDistance;
     }
 
-    protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, MergingMethod mergingMethod) {
+    protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, 
+                                                ObjectToNodeDistance objectToNodeDistance, MergingMethod mergingMethod) {
         List<LocalAbstractObject> objects = nodes
                 .stream()
                 .map(mergingMethod::getObjects)
@@ -169,7 +171,8 @@ public abstract class Node implements Serializable {
     }
 
     private void insertIncremental(LocalAbstractObject object) {
-        hull.addHullObject(object);
+        //hull.addHullObject(object);
+        hull.setHullObjects((List<LocalAbstractObject>) Collections.singleton(object));
     }
 
     private void insertHullRebuild(LocalAbstractObject object) {
diff --git a/src/mhtree/ObjectToNodeDistance.java b/src/mhtree/ObjectToNodeDistance.java
index 42fc7c0..fd8af4e 100644
--- a/src/mhtree/ObjectToNodeDistance.java
+++ b/src/mhtree/ObjectToNodeDistance.java
@@ -52,7 +52,6 @@ public enum ObjectToNodeDistance {
                     .min()
                     .orElse(Double.MAX_VALUE);
         }
-
     };
 
     /**
diff --git a/src/mhtree/ObjectToNodeDistanceRank.java b/src/mhtree/ObjectToNodeDistanceRank.java
index 792089f..946ff48 100644
--- a/src/mhtree/ObjectToNodeDistanceRank.java
+++ b/src/mhtree/ObjectToNodeDistanceRank.java
@@ -45,4 +45,9 @@ public class ObjectToNodeDistanceRank implements Comparable<ObjectToNodeDistance
     public Node getNode() {
         return node;
     }
+
+    public double getDistance() {
+        return distance;
+    }
+    
 }
diff --git a/src/mhtree/benchmarking/KnnResultPair.java b/src/mhtree/benchmarking/KnnResultPair.java
new file mode 100644
index 0000000..d1e196b
--- /dev/null
+++ b/src/mhtree/benchmarking/KnnResultPair.java
@@ -0,0 +1,15 @@
+/*
+ * To change this license header, choose License Headers in Project Properties.
+ * To change this template file, choose Tools | Templates
+ * and open the template in the editor.
+ */
+package mhtree.benchmarking;
+
+import java.util.List;
+import messif.objects.util.RankedAbstractObject;
+
+public class KnnResultPair {
+    public List<RankedAbstractObject> kNNObjects;
+    public String id;
+}
+
diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java
index ebf9d50..9cbfc6b 100644
--- a/src/mhtree/benchmarking/RunBenchmark.java
+++ b/src/mhtree/benchmarking/RunBenchmark.java
@@ -25,7 +25,10 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
 import java.util.stream.Collectors;
+import messif.algorithms.Algorithm;
 
 import static mhtree.ObjectToNodeDistance.AVERAGE;
 import static mhtree.ObjectToNodeDistance.FURTHEST;
@@ -33,84 +36,83 @@ import static mhtree.ObjectToNodeDistance.NEAREST;
 
 public class RunBenchmark {
     public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException {
-        if (args.length != 5) {
+        if (args.length != 7) {
             throw new IllegalArgumentException("Unexpected number of params");
         }
 
-        Statistics.enableGlobally();
+        boolean isMHtree = switch (args[0]) {
+            case "MHULL-TREE" -> true;
+            default -> false;
+        };
+        
+        //Statistics.enableGlobally();
+        Statistics.disableGlobally();
         AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = 16;
 
-        List<LocalAbstractObject> objects = loadDataset(args[0]);
-        int leafCapacity = Integer.parseInt(args[1]);
-        int nodeDegree = Integer.parseInt(args[2]);
-        InsertType insertType = args[3].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY;
-        ObjectToNodeDistance objectToNodeDistance;
-
-        switch (args[4]) {
-            case "FURTHEST":
-                objectToNodeDistance = FURTHEST;
-                break;
-            case "AVERAGE":
-                objectToNodeDistance = AVERAGE;
-                break;
-            default:
-                objectToNodeDistance = NEAREST;
-                break;
-        }
-
-        percentageToRecallMHTree(new MHTreeConfig(
-                        leafCapacity,
-                        nodeDegree,
-                        insertType,
-                        objectToNodeDistance
-                ),
-                objects,
-                new int[]{1, 50}
+        List<LocalAbstractObject> objects = loadDataset(args[1]);
+        int leafCapacity = Integer.parseInt(args[2]);
+        int nodeDegree = Integer.parseInt(args[3]);
+        List<LocalAbstractObject> queries = loadDataset(args[4]);
+        
+        InsertType insertType = args[5].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY;
+
+        ObjectToNodeDistance objectToNodeDistance = switch (args[6]) {
+            case "FURTHEST" -> FURTHEST;
+            case "AVERAGE" -> AVERAGE;
+            default -> NEAREST;
+        };
+        
+        final MHTreeConfig cfg = new MHTreeConfig(
+                leafCapacity,
+                nodeDegree,
+                insertType,
+                objectToNodeDistance
         );
+        final int[] ks = new int[]{1, 3, 5, 10, 20, 50, 100};
+        if (isMHtree) {
+            percentageToRecallMHTree(cfg,
+                    objects,
+                    Arrays.asList(NEAREST),
+                    //Arrays.asList(NEAREST, FURTHEST, AVERAGE),
+                    queries, ks);
+        } else {
+            percentageToRecallMTree(cfg,
+                    objects,
+                    queries, ks);
+        }
     }
 
-    private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException {
+    private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects, 
+            final List<ObjectToNodeDistance> distFuncs,
+            List<LocalAbstractObject> queries,
+            int[] ks) throws BucketStorageException, RuntimeException {
+
+        long buildingStartTimeStamp = System.currentTimeMillis();
+
         MHTree mhTree = new MHTree.MHTreeBuilder(objects, config.leafCapacity, config.nodeDegree)
                 .objectToNodeDistance(config.objectToNodeDistance)
                 .mergingMethod(MergingMethod.HULL_BASED_MERGE)
                 .build();
 
+        long buildinTime = System.currentTimeMillis() - buildingStartTimeStamp;
+        //OperationStatistics.getLocalThreadStatistics().printStatistics();
         mhTree.printStatistics();
+        //System.out.println("Fat factor: " + mhTree.getFatFactor());
+        System.out.println("Building time: " + buildinTime + " msec");
 
-        for (ObjectToNodeDistance dist : Arrays.asList(NEAREST, FURTHEST, AVERAGE)) {
+        System.out.println("kNN queries will be executed for k=" + Arrays.toString(ks));
+        
+        for (ObjectToNodeDistance dist : distFuncs) {
             System.gc();
 
             mhTree.getNodes().forEach(node -> node.objectToNodeDistance = dist);
 
-            System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)");
-
-            int maxK = Arrays.stream(ks).max().getAsInt();
-
-            List<KNNQueryOperation> kNNOperations = objects
-                    .parallelStream()
-                    .map(object -> new KNNQueryOperation(object, maxK))
-                    .collect(Collectors.toList());
+            System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max), time (msec)");
 
-            class Pair {
-                public List<RankedAbstractObject> kNNObjects;
-                public String id;
-            }
-
-            Map<String, List<RankedAbstractObject>> kNNResults = kNNOperations
-                    .parallelStream()
-                    .map(op -> {
-                        mhTree.kNN(op);
-                        Pair pair = new Pair();
-                        pair.id = op.getQueryObject().getLocatorURI();
-                        pair.kNNObjects = new ArrayList<>(maxK);
-                        for (RankedAbstractObject o : op)
-                            pair.kNNObjects.add(o);
-                        return pair;
-                    })
-                    .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects));
+            Map<String, List<RankedAbstractObject>> kNNResults = prepareGroundTruth(ks, queries, mhTree);
 
             for (int k : ks) {
-                List<ApproxKNNQueryOperation> approxOperations = objects
+                List<ApproxKNNQueryOperation> approxOperations = queries
                         .parallelStream()
                         .map(object -> new ApproxKNNQueryOperation(object, k, 0, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE))
                         .collect(Collectors.toList());
@@ -119,18 +121,22 @@ public class RunBenchmark {
                         .forEach(op -> op.suppData = new SearchState(mhTree, op));
 
                 double minimalRecall = 0;
-                int percentage = 0;
-                int percentageStep = 5;
-
-                while (minimalRecall != 1.0) {
+                for (int percentage = 0; percentage <= 100; percentage += 5) {
+                    final int approxLimit = Math.round((float) objects.size() * (float) percentage / 100f);
                     approxOperations
                             .parallelStream()
                             .filter(op -> ((SearchState) op.suppData).recall != 1d)
                             .forEach(op -> {
                                 SearchState searchState = (SearchState) op.suppData;
-                                mhTree.approxKNN(op);
+                                try {
+                                    //mhTree.approxKNN(op);
+                                    mhTree.executeOperation(op);
+                                } catch (AlgorithmMethodException | NoSuchMethodException ex) {
+                                }
+                                searchState.time = op.getParameter("OperationTime", Long.class);
                                 searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults);
-                                searchState.approximateState.limit += Math.round((float) objects.size() * (float) percentageStep / 100f);
+                                searchState.approximateState.limit = approxLimit;
+                                op.resetAnswer();
                             });
 
                     Stats recallStats = new Stats(
@@ -139,6 +145,12 @@ public class RunBenchmark {
                                     .map(op -> ((SearchState) op.suppData).recall)
                                     .collect(Collectors.toList())
                     );
+                    Stats timeStats = new Stats(
+                            approxOperations
+                                    .stream()
+                                    .map(op -> (double)((SearchState) op.suppData).time)
+                                    .collect(Collectors.toList())
+                    );
 
                     System.out.println(String.join(",",
                             String.valueOf(config.leafCapacity),
@@ -150,80 +162,144 @@ public class RunBenchmark {
                                     recallStats.getMin(),
                                     recallStats.getAverage(),
                                     recallStats.getMedian(),
-                                    recallStats.getMax())));
+                                    recallStats.getMax()),
+                            String.format("%.2f", timeStats.getAverage())));
 
                     minimalRecall = recallStats.getMin();
-                    percentage += percentageStep;
+                    if (minimalRecall == 1.0)
+                        break;
                 }
             }
         }
     }
 
-    private static void percentageToRecallMTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, RuntimeException, InstantiationException {
-        int numberOfObjects = objects.size();
+    private static void percentageToRecallMTree(MHTreeConfig config, List<LocalAbstractObject> objects, 
+            List<LocalAbstractObject> queries,
+            int[] ks) throws NoSuchMethodException, AlgorithmMethodException, RuntimeException, InstantiationException {
+
+        long buildingStartTimeStamp = System.currentTimeMillis();
 
         MTree mTree = new MTree(config.nodeDegree, config.leafCapacity);
 
         Collections.shuffle(objects);
 
-        BulkInsertOperation op = new BulkInsertOperation(objects);
+        BulkInsertOperation opIns = new BulkInsertOperation(objects);
 
-        mTree.insert(op);
+        mTree.insert(opIns);
+        long buildingTime = System.currentTimeMillis() - buildingStartTimeStamp;
 
         mTree.printStatistics();
+        //System.out.println("Fat factor: " + mTree.getFatFactor());
+        System.out.println("Building time: " + buildingTime + " msec");
 
-        System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)");
+        System.out.println("kNN queries will be executed for k=" + Arrays.toString(ks));
+        
+        System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max), time (msec)");
 
-        double minimalRecall = 0;
-        int percentage = 0;
-        int percentageStep = 5;
+        Map<String, List<RankedAbstractObject>> kNNResults = prepareGroundTruth(ks, queries, mTree);
+        
 
+//        int numberOfQueries = queries.size();
         for (int k : ks) {
-            List<Double> recalls = new ArrayList<>(numberOfObjects);
-            for (int i = 0; i < numberOfObjects; i++) {
-                recalls.add(0.0);
-            }
-
-            while (minimalRecall != 1.0) {
-                for (int i = 0; i < numberOfObjects; i++) {
-                    if (recalls.get(i) != 1.0) {
-                        ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation(
-                                objects.get(i),
-                                k,
-                                percentage,
-                                Approximate.LocalSearchType.PERCENTAGE,
-                                LocalAbstractObject.UNKNOWN_DISTANCE
-                        );
-
-                        mTree.executeOperation(operation);
-
-                        recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree));
-                    }
-                }
-
-                Stats recallStats = new Stats(new ArrayList<>(recalls));
+            double minimalRecall = 0;
+            for (int percentage = 0; percentage <= 100; percentage += 5) {
+                final int approxLimit = percentage;
+                List<ApproxKNNQueryOperation> approxOperations = queries
+                        .parallelStream()
+                        .map(object -> new ApproxKNNQueryOperation(object, k, approxLimit, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE))
+                        .collect(Collectors.toList());
+                approxOperations
+                        .parallelStream()
+                        .forEach(op -> op.suppData = new SearchInfoState());
+                
+                approxOperations
+                        .parallelStream()
+                        .forEach(op -> {
+                            SearchInfoState searchState = (SearchInfoState) op.suppData;
+                            try {
+                                //mhTree.approxKNN(op);
+                                mTree.executeOperation(op);
+                            } catch (AlgorithmMethodException | NoSuchMethodException ex) {
+                            }
+                            searchState.time = op.getParameter("OperationTime", Long.class);
+                            searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults);
+                            //searchState.approximateState.limit += Math.round((float) objects.size() * (float) percentageStep / 100f);
+                        });
+
+                Stats recallStats = new Stats(
+                        approxOperations
+                                .stream()
+                                .map(op -> ((SearchInfoState) op.suppData).recall)
+                                .collect(Collectors.toList())
+                );
+                Stats timeStats = new Stats(
+                        approxOperations
+                                .stream()
+                                .map(op -> (double)((SearchInfoState) op.suppData).time)
+                                .collect(Collectors.toList())
+                );
+                
+//                for (int i = 0; i < numberOfQueries; i++) {
+//                    if (recalls.get(i) != 1.0) {
+//                        ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation(
+//                                queries.get(i),
+//                                k,
+//                                percentage,
+//                                Approximate.LocalSearchType.PERCENTAGE,
+//                                LocalAbstractObject.UNKNOWN_DISTANCE
+//                        );
+//
+//                        mTree.executeOperation(operation);
+//
+//                        times.set(i, (double)operation.getParameter("OperationTime", Long.class));
+//                        recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree));
+//                    }
+//                }
 
                 System.out.println(String.join(",",
                         String.valueOf(config.leafCapacity),
                         String.valueOf(config.nodeDegree),
-                        String.valueOf(config.objectToNodeDistance),
+                        "",
                         String.valueOf(k),
                         String.valueOf(percentage),
                         String.format("%.2f,%.2f,%.2f,%.2f",
                                 recallStats.getMin(),
                                 recallStats.getAverage(),
                                 recallStats.getMedian(),
-                                recallStats.getMax())));
+                                recallStats.getMax()),
+                        String.format("%.2f", timeStats.getAverage())));
 
                 minimalRecall = recallStats.getMin();
-                percentage += percentageStep;
+                if (minimalRecall == 1.0)
+                    break;
             }
-
-            minimalRecall = 0;
-            percentage = 0;
         }
     }
 
+    private static Map<String, List<RankedAbstractObject>> prepareGroundTruth(int[] ks, List<LocalAbstractObject> queries, Algorithm alg) {
+        int maxK = Arrays.stream(ks).max().getAsInt();
+        
+        List<KNNQueryOperation> kNNOperations = queries
+                .parallelStream()
+                .map(object -> new KNNQueryOperation(object, maxK))
+                .collect(Collectors.toList());
+        Map<String, List<RankedAbstractObject>> kNNResults = kNNOperations
+                .parallelStream()
+                .map((KNNQueryOperation op) -> {
+                    try {
+                        alg.executeOperation(op);
+                    } catch (AlgorithmMethodException | NoSuchMethodException ex) { }
+                    KnnResultPair pair = new KnnResultPair();
+                    pair.id = op.getQueryObject().getLocatorURI();
+                    pair.kNNObjects = new ArrayList<>(maxK);
+                    for (RankedAbstractObject o : op)
+                        pair.kNNObjects.add(o);
+                    return pair;
+                })
+                .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects));
+        return kNNResults;
+    }
+
     private static List<LocalAbstractObject> loadDataset(String path) throws IOException {
         return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path));
     }
diff --git a/src/mhtree/benchmarking/SearchInfoState.java b/src/mhtree/benchmarking/SearchInfoState.java
new file mode 100644
index 0000000..95caf51
--- /dev/null
+++ b/src/mhtree/benchmarking/SearchInfoState.java
@@ -0,0 +1,17 @@
+/*
+ * To change this license header, choose License Headers in Project Properties.
+ * To change this template file, choose Tools | Templates
+ * and open the template in the editor.
+ */
+package mhtree.benchmarking;
+
+public class SearchInfoState {
+    public double recall;
+    public long time;   // in msec
+
+    public SearchInfoState() {
+        this.recall = 0d;
+        this.time = 0;
+    }
+    
+}
diff --git a/src/mhtree/benchmarking/SearchState.java b/src/mhtree/benchmarking/SearchState.java
index 6ff8c6f..05057ae 100644
--- a/src/mhtree/benchmarking/SearchState.java
+++ b/src/mhtree/benchmarking/SearchState.java
@@ -8,17 +8,15 @@ import mhtree.ObjectToNodeDistanceRank;
 
 import java.util.PriorityQueue;
 
-public class SearchState {
+public class SearchState extends SearchInfoState {
     public PriorityQueue<ObjectToNodeDistanceRank> queue;
     public LocalAbstractObject queryObject;
     public ApproximateState approximateState;
-    public double recall;
 
     public SearchState(MHTree tree, ApproxKNNQueryOperation operation) {
         this.queue = new PriorityQueue<>();
         this.queue.add(new ObjectToNodeDistanceRank(operation.getQueryObject(), tree.getRoot(), operation.getK()));
         this.queryObject = operation.getQueryObject();
         this.approximateState = ApproximateState.create(operation, tree);
-        this.recall = 0d;
     }
 }
-- 
GitLab