From dd40600b22cb8c133cc41fed1f4370ed59166e09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Wed, 21 Apr 2021 11:51:43 +0200 Subject: [PATCH] ADD: M-Tree benchmark logic --- src/mhtree/Node.java | 14 ++--- src/mhtree/benchmarking/RunBenchmark.java | 76 ++++++++++++++++++++++- 2 files changed, 80 insertions(+), 10 deletions(-) diff --git a/src/mhtree/Node.java b/src/mhtree/Node.java index ddaef35..600e80f 100644 --- a/src/mhtree/Node.java +++ b/src/mhtree/Node.java @@ -18,8 +18,8 @@ public abstract class Node implements Serializable { */ private static final long serialVersionUID = 420L; - private final InsertType INSERT_TYPE; - private final ObjectToNodeDistance OBJECT_TO_NODE_DISTANCE; + private final InsertType insertType; + private final ObjectToNodeDistance objectToNodeDistance; private HullOptimizedRepresentationV3 hull; @@ -27,8 +27,8 @@ public abstract class Node implements Serializable { this.hull = new HullOptimizedRepresentationV3(distances); this.hull.build(); - this.INSERT_TYPE = insertType; - this.OBJECT_TO_NODE_DISTANCE = objectToNodeDistance; + this.insertType = insertType; + this.objectToNodeDistance = objectToNodeDistance; } protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, MergeType mergeType) { @@ -70,11 +70,11 @@ public abstract class Node implements Serializable { } protected double getDistance(LocalAbstractObject object) { - return OBJECT_TO_NODE_DISTANCE.getDistance(object, this); + return objectToNodeDistance.getDistance(object, this); } protected double getDistance(LocalAbstractObject object, PrecomputedDistances distances) { - return OBJECT_TO_NODE_DISTANCE.getDistance(object, this, distances); + return objectToNodeDistance.getDistance(object, this, distances); } protected double getDistanceToNearest(LocalAbstractObject object) { @@ -100,7 +100,7 @@ public abstract class Node implements 
Serializable { protected void addObjectIntoHull(LocalAbstractObject object, PrecomputedDistances distances) { if (isCovered(object, distances)) return; - if (INSERT_TYPE == InsertType.INCREMENTAL) { + if (insertType == InsertType.INCREMENTAL) { hull.addHullObject(object); return; } diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java index 1dc0fc4..aa723ff 100644 --- a/src/mhtree/benchmarking/RunBenchmark.java +++ b/src/mhtree/benchmarking/RunBenchmark.java @@ -1,6 +1,7 @@ package mhtree.benchmarking; import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; +import messif.algorithms.AlgorithmMethodException; import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; @@ -8,6 +9,7 @@ import messif.objects.util.AbstractObjectList; import messif.objects.util.RankedAbstractObject; import messif.objects.util.StreamGenericAbstractObjectIterator; import messif.operations.Approximate; +import messif.operations.data.BulkInsertOperation; import messif.operations.query.ApproxKNNQueryOperation; import messif.operations.query.KNNQueryOperation; import messif.statistics.Statistics; @@ -15,16 +17,18 @@ import mhtree.InsertType; import mhtree.MHTree; import mhtree.MergeType; import mhtree.ObjectToNodeDistance; +import mtree.MTree; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; public class RunBenchmark { - public static void main(String[] args) throws IOException, BucketStorageException { + public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException { if (args.length != 5) { throw new IllegalArgumentException("Unexpected number of params"); } @@ -50,7 +54,7 @@ public class RunBenchmark { break; 
}
 
-        percentageToRecall(new MHTreeConfig(
+        percentageToRecallMHTree(new MHTreeConfig(
                 leafCapacity,
                 nodeDegree,
                 insertType,
@@ -61,7 +65,7 @@ public class RunBenchmark {
         );
     }
 
-    private static void percentageToRecall(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException {
+    private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException {
         MHTree mTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree)
                 .objectToNodeDistance(config.objectToNodeDistance)
                 .mergeType(MergeType.REPRESENTATION_BASED)
@@ -145,6 +149,72 @@ public class RunBenchmark {
         }
     }
 
+    /**
+     * Benchmarks an M-Tree: for every k, raises the approximate-search percentage step by step
+     * until every query reaches recall 1.0 (or 100 % is exceeded), printing one CSV row per step.
+     *
+     * @param config  nodeDegree and leafCapacity configure the tree; objectToNodeDistance is only echoed
+     * @param objects dataset that is inserted and then used as the query set; the list is not modified
+     * @param ks      values of k to evaluate
+     */
+    private static void percentageToRecallMTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, RuntimeException, InstantiationException {
+        final int percentageStep = 5;
+        final int maxPercentage = 100;
+        int numberOfObjects = objects.size();
+
+        // Insert a shuffled copy so the caller's dataset keeps its original order.
+        List<LocalAbstractObject> insertionOrder = new ArrayList<>(objects);
+        Collections.shuffle(insertionOrder);
+
+        MTree mTree = new MTree(config.nodeDegree, config.leafCapacity);
+        mTree.insert(new BulkInsertOperation(insertionOrder));
+        mTree.printStatistics();
+
+        System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)");
+
+        for (int k : ks) {
+            List<Double> recalls = new ArrayList<>(Collections.nCopies(numberOfObjects, 0.0));
+            double minimalRecall = 0;
+            int percentage = 0;
+
+            // The upper bound guarantees termination even if some query never reaches recall 1.0.
+            while (minimalRecall < 1.0 && percentage <= maxPercentage) {
+                for (int i = 0; i < numberOfObjects; i++) {
+                    if (recalls.get(i) < 1.0) {
+                        ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation(
+                                objects.get(i),
+                                k,
+                                percentage,
+                                Approximate.LocalSearchType.PERCENTAGE,
+                                LocalAbstractObject.UNKNOWN_DISTANCE
+                        );
+
+                        mTree.executeOperation(operation);
+
+                        recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree));
+                    }
+                }
+
+                Stats recallStats = new Stats(new ArrayList<>(recalls));
+
+                System.out.println(String.join(",",
+                        String.valueOf(config.leafCapacity),
+                        String.valueOf(config.nodeDegree),
+                        String.valueOf(config.objectToNodeDistance),
+                        String.valueOf(k),
+                        String.valueOf(percentage),
+                        String.format("%.2f,%.2f,%.2f,%.2f",
+                                recallStats.getMin(),
+                                recallStats.getAverage(),
+                                recallStats.getMedian(),
+                                recallStats.getMax())));
+
+                minimalRecall = recallStats.getMin();
+                percentage += percentageStep;
+            }
+        }
+    }
+
     private static List<LocalAbstractObject> loadDataset(String path) throws IOException {
         return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path));
     }
-- 
GitLab