From c0a25fd8a1f82eb0bd5209ef3c345059426c79fa Mon Sep 17 00:00:00 2001
From: Vlastislav Dohnal <dohnal@fi.muni.cz>
Date: Thu, 28 Apr 2022 10:55:54 +0200
Subject: [PATCH] * Added object class parameter to benchmark * ApproxKnn oper
 changed to its predecessor Knn oper

---
 .../benchmarking/PerformanceMeasures.java     | 10 ++--
 src/mhtree/benchmarking/RunBenchmark.java     | 50 +++++++++++--------
 src/mhtree/benchmarking/Stats.java            |  2 +-
 3 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/src/mhtree/benchmarking/PerformanceMeasures.java b/src/mhtree/benchmarking/PerformanceMeasures.java
index dee978a..8d61f34 100644
--- a/src/mhtree/benchmarking/PerformanceMeasures.java
+++ b/src/mhtree/benchmarking/PerformanceMeasures.java
@@ -4,7 +4,6 @@ import messif.algorithms.Algorithm;
 import messif.algorithms.AlgorithmMethodException;
 import messif.objects.util.DistanceRankedObject;
 import messif.objects.util.RankedAbstractObject;
-import messif.operations.query.ApproxKNNQueryOperation;
 import messif.operations.query.KNNQueryOperation;
 
 import java.util.ArrayList;
@@ -16,7 +15,7 @@ import java.util.stream.Collectors;
 
 public class PerformanceMeasures {
 
-    public static double measureErrorOnThePosition(ApproxKNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException {
+    public static double measureErrorOnThePosition(KNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException {
         KNNQueryOperation rankedObjects = new KNNQueryOperation(approxKNNQueryOperation.getQueryObject(), tree.getObjectCount());
         tree.executeOperation(rankedObjects);
 
@@ -40,7 +39,7 @@ public class PerformanceMeasures {
     }
 
     // comparing done based on distances, counts how many of the same distances of KNNQueryOperation were presents in the answer of ApproxKNNQueryOperation
-    public static double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException {
+    public static double measureRecall(KNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException {
         if (approxKNNQueryOperation.getAnswerCount() == 0) return 0d;
 
         KNNQueryOperation knnQueryOperation = new KNNQueryOperation(approxKNNQueryOperation.getQueryObject(), approxKNNQueryOperation.getK());
@@ -73,10 +72,11 @@ public class PerformanceMeasures {
         return trueKNNFoundCount / (double) knnQueryOperation.getAnswerCount();
     }
 
-    public static double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation, Map<String, List<RankedAbstractObject>> trueKNN) {
+    public static double measureRecall(KNNQueryOperation approxKNNQueryOperation, Map<String, List<RankedAbstractObject>> trueKNN) {
         if (approxKNNQueryOperation.getAnswerCount() == 0) return 0d;
 
-        List<RankedAbstractObject> kNNObjects = trueKNN.get(approxKNNQueryOperation.getQueryObject().getLocatorURI()).subList(0, approxKNNQueryOperation.getK());
+        List<RankedAbstractObject> kNNObjects = trueKNN.get(approxKNNQueryOperation.getQueryObject().getLocatorURI())
+                                                    .subList(0, approxKNNQueryOperation.getK());
 
         Map<Float, Long> frequencyMap = kNNObjects
                 .stream()
diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java
index 9f1bfa9..72fe191 100644
--- a/src/mhtree/benchmarking/RunBenchmark.java
+++ b/src/mhtree/benchmarking/RunBenchmark.java
@@ -4,7 +4,6 @@ import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation;
 import messif.algorithms.AlgorithmMethodException;
 import messif.buckets.BucketStorageException;
 import messif.objects.LocalAbstractObject;
-import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2;
 import messif.objects.util.AbstractObjectList;
 import messif.objects.util.RankedAbstractObject;
 import messif.objects.util.StreamGenericAbstractObjectIterator;
@@ -24,19 +23,19 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
-import java.util.logging.Level;
-import java.util.logging.Logger;
 import java.util.stream.Collectors;
 import messif.algorithms.Algorithm;
+import messif.objects.keys.AbstractObjectKey;
 
 import static mhtree.ObjectToNodeDistance.AVERAGE;
 import static mhtree.ObjectToNodeDistance.FURTHEST;
 import static mhtree.ObjectToNodeDistance.NEAREST;
 
 public class RunBenchmark {
-    public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException {
-        if (args.length != 7 && args.length != 10) {
+    public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException, ClassNotFoundException {
+        if (args.length != 8 && args.length != 11) {
             throw new IllegalArgumentException("Unexpected number of params");
         }
 
@@ -45,18 +44,21 @@ public class RunBenchmark {
             default -> false;
         };
         
+        // e.g. messif.objects.impl.ObjectFloatVectorNeuralNetworkL2
+        Class<? extends LocalAbstractObject> objClass = (Class<? extends LocalAbstractObject>) Class.forName(args[1]);   
+        
         //Statistics.enableGlobally();
         Statistics.disableGlobally();
         AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = 16;
 
-        List<LocalAbstractObject> objects = loadDataset(args[1]);
-        int leafCapacity = Integer.parseInt(args[2]);
-        int nodeDegree = Integer.parseInt(args[3]);
-        List<LocalAbstractObject> queries = loadDataset(args[4]);
+        List<LocalAbstractObject> objects = loadDataset(args[2], objClass);
+        int leafCapacity = Integer.parseInt(args[3]);
+        int nodeDegree = Integer.parseInt(args[4]);
+        List<LocalAbstractObject> queries = loadDataset(args[5], objClass);
         
-        InsertType insertType = args[5].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY;
+        InsertType insertType = args[6].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY;
 
-        ObjectToNodeDistance objectToNodeDistance = switch (args[6]) {
+        ObjectToNodeDistance objectToNodeDistance = switch (args[7]) {
             case "FURTHEST" -> FURTHEST;
             case "AVERAGE" -> AVERAGE;
             default -> NEAREST;
@@ -75,17 +77,17 @@ public class RunBenchmark {
                     Arrays.asList(NEAREST),
                     //Arrays.asList(NEAREST, FURTHEST, AVERAGE),
                     queries, ks);
-        } else if (args.length == 7) {      // Original M-tree
+        } else if (args.length == 8) {      // Original M-tree
             percentageToRecallMTree(cfg,
                     objects,
                     queries, ks,
                     0, 0, null);
         } else {    // Pivoting M-tree
-            List<LocalAbstractObject> pivots = loadDataset(args[9]);
+            List<LocalAbstractObject> pivots = loadDataset(args[10], objClass);
             percentageToRecallMTree(cfg,
                     objects,
                     queries, ks,
-                    Integer.parseInt(args[7]), Integer.parseInt(args[8]), pivots);
+                    Integer.parseInt(args[8]), Integer.parseInt(args[9]), pivots);
         }
     }
 
@@ -202,8 +204,10 @@ public class RunBenchmark {
         mTree.insert(opIns);
         long buildingTime = System.currentTimeMillis() - buildingStartTimeStamp;
 
+        mTree.markObjectsWithBucketIds();
         mTree.printStatistics();
-        System.out.println("Fat factor: " + mTree.getFatFactor());
+        //mTree.checkConsistency();
+        //System.out.println("Fat factor: " + mTree.getFatFactor());
         System.out.println("Building time: " + buildingTime + " msec");
 
         System.out.println("kNN queries will be executed for k=" + Arrays.toString(ks));
@@ -257,16 +261,16 @@ public class RunBenchmark {
                         "",
                         String.valueOf(k),
                         String.valueOf(percentage),
-                        String.format("%.2f,%.2f,%.2f,%.2f",
+                        String.format(Locale.ENGLISH, "%.2f,%.2f,%.2f,%.2f",
                                 recallStats.getMin(),
                                 recallStats.getAverage(),
                                 recallStats.getMedian(),
                                 recallStats.getMax()),
-                        String.format("%.2f", timeStats.getAverage())));
+                        String.format(Locale.ENGLISH, "%.2f", timeStats.getAverage())));
 
                 minimalRecall = recallStats.getMin();
-                if (minimalRecall == 1.0)
-                    break;
+                //if (minimalRecall == 1.0)
+                //    break;
             }
         }
     }
@@ -285,17 +289,21 @@ public class RunBenchmark {
                         alg.executeOperation(op);
                     } catch (AlgorithmMethodException | NoSuchMethodException ex) { }
                     KnnResultPair pair = new KnnResultPair();
+                    if (!AbstractObjectKey.class.equals(op.getQueryObject().getObjectKey().getClass()))
+                        System.out.println("ERROR: Ground truth query has a non-expected abstract key! " + op.getQueryObject().getObjectKey());
                     pair.id = op.getQueryObject().getLocatorURI();
                     pair.kNNObjects = new ArrayList<>(maxK);
                     for (RankedAbstractObject o : op)
                         pair.kNNObjects.add(o);
+                    if (pair.kNNObjects.size() != maxK)
+                        System.out.println("ERROR: Ground truth for query " + pair.id + " constains " + pair.kNNObjects.size() + " objects, Expected " + maxK);
                     return pair;
                 })
                 .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects));
         return kNNResults;
     }
 
-    private static List<LocalAbstractObject> loadDataset(String path) throws IOException {
-        return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path));
+    private static List<LocalAbstractObject> loadDataset(String path, Class<? extends LocalAbstractObject> objCls) throws IOException {
+        return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(objCls, path));
     }
 }
diff --git a/src/mhtree/benchmarking/Stats.java b/src/mhtree/benchmarking/Stats.java
index f9d793f..8b760a4 100644
--- a/src/mhtree/benchmarking/Stats.java
+++ b/src/mhtree/benchmarking/Stats.java
@@ -6,7 +6,7 @@ import java.util.List;
 public class Stats {
     List<Double> sortedData;
 
-    Stats(List<Double> data) {
+    public Stats(List<Double> data) {
         sortedData = data;
         Collections.sort(sortedData);
     }
-- 
GitLab