From b4cdc7a4722298bb14baaf813ce60592644eac42 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev>
Date: Wed, 24 Feb 2021 13:56:36 +0100
Subject: [PATCH] ADD: visited leaves counter

- exit condition in approxKNN based on number of visited leaves
- number of nodes and hull objects in each level in tree statistics
- fixed recall measurement
---
 src/mhtree/MHTree.java                        | 88 +++++++++++++------
 ...Measure.java => ObjectToNodeDistance.java} |  3 +-
 2 files changed, 60 insertions(+), 31 deletions(-)
 rename src/mhtree/{DistanceMeasure.java => ObjectToNodeDistance.java} (67%)

diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java
index 2280488..f4c2170 100644
--- a/src/mhtree/MHTree.java
+++ b/src/mhtree/MHTree.java
@@ -5,11 +5,13 @@ import messif.buckets.BucketDispatcher;
 import messif.buckets.BucketErrorCode;
 import messif.buckets.BucketStorageException;
 import messif.buckets.LocalBucket;
-import messif.objects.AbstractObject;
 import messif.objects.LocalAbstractObject;
+import messif.operations.Approximate;
 import messif.operations.data.InsertOperation;
 import messif.operations.query.ApproxKNNQueryOperation;
 import messif.operations.query.KNNQueryOperation;
+import messif.statistics.StatisticCounter;
+import mhtree.benchmarking.Utils;
 
 import java.io.Serializable;
 import java.util.*;
@@ -27,6 +29,9 @@ public class MHTree extends Algorithm implements Serializable {
     private final Node root;
     private final BucketDispatcher bucketDispatcher;
     private final InsertType insertType;
+    private final ObjectToNodeDistance objectToNodeDistance;
+
+    private final StatisticCounter statVisitedLeaves = StatisticCounter.getStatistics("Node.Leaf.Visited");
 
     @AlgorithmConstructor(description = "MH-Tree", arguments = {
             "list of objects",
@@ -38,21 +43,24 @@ public class MHTree extends Algorithm implements Serializable {
             "storage class for buckets",
             "storage class parameters"
     })
-    public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, DistanceMeasure distanceMeasure, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException {
+    public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException {
         super("MH-Tree");
 
         this.leafCapacity = leafCapacity;
         this.arity = arity;
         this.insertType = insertType;
+        this.objectToNodeDistance = objectToNodeDistance;
 
         bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams);
 
-        root = new BuildTree(objects, leafCapacity, arity, insertType, distanceMeasure, bucketDispatcher).getRoot();
+        root = new BuildTree(objects, leafCapacity, arity, insertType, objectToNodeDistance, bucketDispatcher).getRoot();
     }
 
-    public void approxKNN(ApproxKNNQueryOperation operation) {
+    public void approxKNN(ApproxKNNQueryOperation operation, double coefficient) {
         LocalAbstractObject queryObject = operation.getQueryObject();
 
+        boolean limitVisitedLeaves = operation.getLocalSearchType() == Approximate.LocalSearchType.DATA_PARTITIONS;
+
         PriorityQueue<ObjectToNodeDistanceRank> queue = new PriorityQueue<>();
         queue.add(new ObjectToNodeDistanceRank(queryObject, root));
 
@@ -60,24 +68,28 @@ public class MHTree extends Algorithm implements Serializable {
             Node node = queue.poll().getNode();
 
             if (node.isLeaf()) {
-                for (LocalAbstractObject object : node.getObjects()) {
-                    if (operation.getAnswerCount() >= operation.getK() && object.getDistance(queryObject) > operation.getAnswerDistance())
-                        continue;
+                if (limitVisitedLeaves && statVisitedLeaves.get() >= operation.getLocalSearchParam())
+                    break;
 
-                    operation.addToAnswer(object);
-                }
+                statVisitedLeaves.add();
 
-                if (operation.getAnswerCount() >= operation.getK())
-                    break;
+                for (LocalAbstractObject object : node.getObjects())
+                    if (!operation.isAnswerFull() || queryObject.getDistance(object) < operation.getAnswerDistance())
+                        operation.addToAnswer(object);
             } else {
                 for (Node child : ((InternalNode) node).getChildren())
-                    queue.add(new ObjectToNodeDistanceRank(queryObject, child));
+                    if (!operation.isAnswerFull() || !isPrunable(child, queryObject, operation, coefficient))
+                        queue.add(new ObjectToNodeDistanceRank(queryObject, child));
             }
         }
 
         operation.endOperation();
     }
 
+    private boolean isPrunable(Node child, LocalAbstractObject queryObject, ApproxKNNQueryOperation operation, double coefficient) {
+        return operation.getAnswerDistance() * coefficient < child.getDistanceToNearestHullObject(queryObject);
+    }
+
     public void insert(InsertOperation operation) throws BucketStorageException {
         LocalAbstractObject object = operation.getInsertedObject();
 
@@ -95,9 +107,11 @@ public class MHTree extends Algorithm implements Serializable {
     }
 
     public void printStatistics() {
-        List<Integer> objectCounts = root.getLeafNodesObjectCounts();
-
-        IntSummaryStatistics statistics = objectCounts.stream().mapToInt(Integer::valueOf).summaryStatistics();
+        IntSummaryStatistics leafStatistics = bucketDispatcher
+                .getAllBuckets()
+                .stream()
+                .mapToInt(LocalBucket::getObjectCount)
+                .summaryStatistics();
 
         System.out.println("--- STATISTICS ---");
 
@@ -108,31 +122,45 @@ public class MHTree extends Algorithm implements Serializable {
         System.out.println("Arity: " + arity);
         System.out.println("Leaf capacity: " + leafCapacity + " objects");
 
+        System.out.println("Number of nodes in each level:");
+        for (int level = 1; level <= root.getHeight() + 1; level++) {
+            System.out.println("- Level " + level + " -> " + root.getNodesOnLevel(level).size());
+        }
+
+        System.out.println("Number of hull objects in each level:");
+        for (int level = 1; level <= root.getHeight() + 1; level++) {
+            Set<Node> levelNodes = root.getNodesOnLevel(level);
+
+            System.out.println("- Level " + level + " -> " + levelNodes.stream().mapToInt(n -> n.getHullObjects().size()).summaryStatistics());
+        }
+
         System.out.println("\nHistogram of covered objects per level: ");
         System.out.println(Histogram.generate(root));
 
         System.out.println("- LeafNodes -");
 
-        System.out.println("Count: " + statistics.getCount());
-        System.out.println("Minimum number of objects: " + statistics.getMin());
-        System.out.println("Average number of objects: " + String.format("%.2f", statistics.getAverage()));
-        System.out.println("Maximum number of objects: " + statistics.getMax());
+        System.out.println("Count: " + leafStatistics.getCount());
+        System.out.println("Minimum number of objects: " + leafStatistics.getMin());
+        System.out.println("Average number of objects: " + String.format("%.2f", leafStatistics.getAverage()));
+        System.out.println("Maximum number of objects: " + leafStatistics.getMax());
     }
 
-    public double measureRecall(ApproxKNNQueryOperation approxKNNOperation) {
-        if (approxKNNOperation.getAnswerCount() == 0) return 0d;
+    public long getVisitedLeaves() {
+        return statVisitedLeaves.get();
+    }
 
-        KNNQueryOperation KNNOperation = new KNNQueryOperation(approxKNNOperation.getQueryObject(), approxKNNOperation.getK());
-        KNN(KNNOperation);
+    public void reset() {
+        statVisitedLeaves.reset();
+    }
 
-        int trueKNNCount = 0;
+    public double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation) {
+        LocalAbstractObject queryObject = approxKNNQueryOperation.getQueryObject();
+        int k = approxKNNQueryOperation.getK();
 
-        for (Iterator<AbstractObject> KNNIt = KNNOperation.getAnswerObjects(); KNNIt.hasNext(); )
-            for (Iterator<AbstractObject> approxKNNIt = approxKNNOperation.getAnswerObjects(); approxKNNIt.hasNext(); )
-                if (KNNIt.next().getLocatorURI().equals(approxKNNIt.next().getLocatorURI()))
-                    trueKNNCount++;
+        KNNQueryOperation knnQueryOperation = new KNNQueryOperation(queryObject, k);
+        KNN(knnQueryOperation);
 
-        return trueKNNCount / (double) KNNOperation.getAnswerCount();
+        return Utils.measureRecall(approxKNNQueryOperation, knnQueryOperation);
     }
 
     private void KNN(KNNQueryOperation operation) {
@@ -146,6 +174,8 @@ public class MHTree extends Algorithm implements Serializable {
                 "leafCapacity=" + leafCapacity +
                 ", arity=" + arity +
                 ", root=" + root +
+                ", insertType=" + insertType +
+                ", objectToNodeDistance=" + objectToNodeDistance +
                 '}';
     }
 }
diff --git a/src/mhtree/DistanceMeasure.java b/src/mhtree/ObjectToNodeDistance.java
similarity index 67%
rename from src/mhtree/DistanceMeasure.java
rename to src/mhtree/ObjectToNodeDistance.java
index 9486ec8..fc1d070 100644
--- a/src/mhtree/DistanceMeasure.java
+++ b/src/mhtree/ObjectToNodeDistance.java
@@ -3,6 +3,5 @@ package mhtree;
-public enum DistanceMeasure {
+public enum ObjectToNodeDistance {
     NEAREST_HULL_OBJECT,
     FURTHEST_HULL_OBJECT,
-    SUM_OF_DISTANCES_TO_HULL_OBJECTS,
-    MEDOID
+    AVERAGE_DISTANCE
 }
-- 
GitLab