From b4cdc7a4722298bb14baaf813ce60592644eac42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Wed, 24 Feb 2021 13:56:36 +0100 Subject: [PATCH] ADD: visited leaves counter - exit condition in approxKNN based on number of visited leaves - number of nodes and hull objects in each level in tree statistics - fixed reacall measurement --- src/mhtree/MHTree.java | 88 +++++++++++++------ ...Measure.java => ObjectToNodeDistance.java} | 3 +- 2 files changed, 60 insertions(+), 31 deletions(-) rename src/mhtree/{DistanceMeasure.java => ObjectToNodeDistance.java} (67%) diff --git a/src/mhtree/MHTree.java b/src/mhtree/MHTree.java index 2280488..f4c2170 100644 --- a/src/mhtree/MHTree.java +++ b/src/mhtree/MHTree.java @@ -5,11 +5,13 @@ import messif.buckets.BucketDispatcher; import messif.buckets.BucketErrorCode; import messif.buckets.BucketStorageException; import messif.buckets.LocalBucket; -import messif.objects.AbstractObject; import messif.objects.LocalAbstractObject; +import messif.operations.Approximate; import messif.operations.data.InsertOperation; import messif.operations.query.ApproxKNNQueryOperation; import messif.operations.query.KNNQueryOperation; +import messif.statistics.StatisticCounter; +import mhtree.benchmarking.Utils; import java.io.Serializable; import java.util.*; @@ -27,6 +29,9 @@ public class MHTree extends Algorithm implements Serializable { private final Node root; private final BucketDispatcher bucketDispatcher; private final InsertType insertType; + private final ObjectToNodeDistance objectToNodeDistance; + + private final StatisticCounter statVisitedLeaves = StatisticCounter.getStatistics("Node.Leaf.Visited"); @AlgorithmConstructor(description = "MH-Tree", arguments = { "list of objects", @@ -38,21 +43,24 @@ public class MHTree extends Algorithm implements Serializable { "storage class for buckets", "storage class parameters" }) - public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, DistanceMeasure distanceMeasure, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { + public MHTree(List<LocalAbstractObject> objects, int leafCapacity, int arity, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, Class<? extends LocalBucket> defaultBucketClass, Map<String, Object> bucketClassParams) throws BucketStorageException { super("MH-Tree"); this.leafCapacity = leafCapacity; this.arity = arity; this.insertType = insertType; + this.objectToNodeDistance = objectToNodeDistance; bucketDispatcher = new BucketDispatcher(Integer.MAX_VALUE, Long.MAX_VALUE, leafCapacity, 0, false, defaultBucketClass, bucketClassParams); - root = new BuildTree(objects, leafCapacity, arity, insertType, distanceMeasure, bucketDispatcher).getRoot(); + root = new BuildTree(objects, leafCapacity, arity, insertType, objectToNodeDistance, bucketDispatcher).getRoot(); } - public void approxKNN(ApproxKNNQueryOperation operation) { + public void approxKNN(ApproxKNNQueryOperation operation, double coefficient) { LocalAbstractObject queryObject = operation.getQueryObject(); + boolean limitVisitedLeaves = operation.getLocalSearchType() == Approximate.LocalSearchType.DATA_PARTITIONS; + PriorityQueue<ObjectToNodeDistanceRank> queue = new PriorityQueue<>(); queue.add(new ObjectToNodeDistanceRank(queryObject, root)); @@ -60,24 +68,28 @@ public class MHTree extends Algorithm implements Serializable { Node node = queue.poll().getNode(); if (node.isLeaf()) { - for (LocalAbstractObject object : node.getObjects()) { - if (operation.getAnswerCount() >= operation.getK() && object.getDistance(queryObject) > operation.getAnswerDistance()) - continue; + if (limitVisitedLeaves && statVisitedLeaves.get() == operation.getLocalSearchParam()) + break; - operation.addToAnswer(object); - } + statVisitedLeaves.add(); - if (operation.getAnswerCount() >= operation.getK()) - break; + for (LocalAbstractObject object : node.getObjects()) + if (!operation.isAnswerFull() || queryObject.getDistance(object) < operation.getAnswerDistance()) + operation.addToAnswer(object); } else { for (Node child : ((InternalNode) node).getChildren()) - queue.add(new ObjectToNodeDistanceRank(queryObject, child)); + if (!operation.isAnswerFull() || !isPrunable(child, queryObject, operation, coefficient)) + queue.add(new ObjectToNodeDistanceRank(queryObject, child)); } } operation.endOperation(); } + private boolean isPrunable(Node child, LocalAbstractObject queryObject, ApproxKNNQueryOperation operation, double coefficient) { + return operation.getAnswerDistance() * coefficient < child.getDistanceToNearestHullObject(queryObject); + } + public void insert(InsertOperation operation) throws BucketStorageException { LocalAbstractObject object = operation.getInsertedObject(); @@ -95,9 +107,11 @@ public class MHTree extends Algorithm implements Serializable { } public void printStatistics() { - List<Integer> objectCounts = root.getLeafNodesObjectCounts(); - - IntSummaryStatistics statistics = objectCounts.stream().mapToInt(Integer::valueOf).summaryStatistics(); + IntSummaryStatistics leafStatistics = bucketDispatcher + .getAllBuckets() + .stream() + .mapToInt(LocalBucket::getObjectCount) + .summaryStatistics(); System.out.println("--- STATISTICS ---"); @@ -108,31 +122,45 @@ public class MHTree extends Algorithm implements Serializable { System.out.println("Arity: " + arity); System.out.println("Leaf capacity: " + leafCapacity + " objects"); + System.out.println("Number of nodes in each level:"); + for (int level = 1; level <= root.getHeight() + 1; level++) { + System.out.println("- Level " + level + " -> " + root.getNodesOnLevel(level).size()); + } + + System.out.println("Number of hull objects in each level:"); + for (int level = 1; level <= root.getHeight() + 1; level++) { + Set<Node> levelNodes = root.getNodesOnLevel(level); + + System.out.println("- Level " + level + " -> " + levelNodes.stream().mapToInt(n -> n.getHullObjects().size()).summaryStatistics()); + } + System.out.println("\nHistogram of covered objects per level: "); System.out.println(Histogram.generate(root)); System.out.println("- LeafNodes -"); - System.out.println("Count: " + statistics.getCount()); - System.out.println("Minimum number of objects: " + statistics.getMin()); - System.out.println("Average number of objects: " + String.format("%.2f", statistics.getAverage())); - System.out.println("Maximum number of objects: " + statistics.getMax()); + System.out.println("Count: " + leafStatistics.getCount()); + System.out.println("Minimum number of objects: " + leafStatistics.getMin()); + System.out.println("Average number of objects: " + String.format("%.2f", leafStatistics.getAverage())); + System.out.println("Maximum number of objects: " + leafStatistics.getMax()); } - public double measureRecall(ApproxKNNQueryOperation approxKNNOperation) { - if (approxKNNOperation.getAnswerCount() == 0) return 0d; + public long getVisitedLeaves() { + return statVisitedLeaves.get(); + } - KNNQueryOperation KNNOperation = new KNNQueryOperation(approxKNNOperation.getQueryObject(), approxKNNOperation.getK()); - KNN(KNNOperation); + public void reset() { + statVisitedLeaves.reset(); + } - int trueKNNCount = 0; + public double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation) { + LocalAbstractObject queryObject = approxKNNQueryOperation.getQueryObject(); + int k = approxKNNQueryOperation.getK(); - for (Iterator<AbstractObject> KNNIt = KNNOperation.getAnswerObjects(); KNNIt.hasNext(); ) - for (Iterator<AbstractObject> approxKNNIt = approxKNNOperation.getAnswerObjects(); approxKNNIt.hasNext(); ) - if (KNNIt.next().getLocatorURI().equals(approxKNNIt.next().getLocatorURI())) - trueKNNCount++; + KNNQueryOperation knnQueryOperation = new KNNQueryOperation(queryObject, k); + KNN(knnQueryOperation); - return trueKNNCount / (double) KNNOperation.getAnswerCount(); + return Utils.measureRecall(approxKNNQueryOperation, knnQueryOperation); } private void KNN(KNNQueryOperation operation) { @@ -146,6 +174,8 @@ public class MHTree extends Algorithm implements Serializable { "leafCapacity=" + leafCapacity + ", arity=" + arity + ", root=" + root + + ", insertType" + insertType + + ", objectToNodeDistance" + objectToNodeDistance + '}'; } } diff --git a/src/mhtree/DistanceMeasure.java b/src/mhtree/ObjectToNodeDistance.java similarity index 67% rename from src/mhtree/DistanceMeasure.java rename to src/mhtree/ObjectToNodeDistance.java index 9486ec8..fc1d070 100644 --- a/src/mhtree/DistanceMeasure.java +++ b/src/mhtree/ObjectToNodeDistance.java @@ -3,6 +3,5 @@ package mhtree; public enum DistanceMeasure { NEAREST_HULL_OBJECT, FURTHEST_HULL_OBJECT, - SUM_OF_DISTANCES_TO_HULL_OBJECTS, - MEDOID + AVERAGE_DISTANCE } -- GitLab