diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java index aa723ffad6df86409c775fb3d247b8ab6ae8ca7e..ebf9d50c576942cb8fb233576beadb3ae6b7fd23 100644 --- a/src/mhtree/benchmarking/RunBenchmark.java +++ b/src/mhtree/benchmarking/RunBenchmark.java @@ -15,7 +15,7 @@ import messif.operations.query.KNNQueryOperation; import messif.statistics.Statistics; import mhtree.InsertType; import mhtree.MHTree; -import mhtree.MergeType; +import mhtree.MergingMethod; import mhtree.ObjectToNodeDistance; import mtree.MTree; @@ -27,6 +27,10 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import static mhtree.ObjectToNodeDistance.AVERAGE; +import static mhtree.ObjectToNodeDistance.FURTHEST; +import static mhtree.ObjectToNodeDistance.NEAREST; + public class RunBenchmark { public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException { if (args.length != 5) { @@ -44,13 +48,13 @@ public class RunBenchmark { switch (args[4]) { case "FURTHEST": - objectToNodeDistance = ObjectToNodeDistance.FURTHEST; + objectToNodeDistance = FURTHEST; break; case "AVERAGE": - objectToNodeDistance = ObjectToNodeDistance.AVERAGE; + objectToNodeDistance = AVERAGE; break; default: - objectToNodeDistance = ObjectToNodeDistance.NEAREST; + objectToNodeDistance = NEAREST; break; } @@ -61,90 +65,96 @@ public class RunBenchmark { objectToNodeDistance ), objects, - new int[]{1, 10, 50, 100} + new int[]{1, 50} ); } private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException { - MHTree mTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree) + MHTree mhTree = new MHTree.MHTreeBuilder(objects, config.leafCapacity, config.nodeDegree) .objectToNodeDistance(config.objectToNodeDistance) - .mergeType(MergeType.REPRESENTATION_BASED) + .mergingMethod(MergingMethod.HULL_BASED_MERGE) .build(); - mTree.printStatistics(); + mhTree.printStatistics(); - System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)"); + for (ObjectToNodeDistance dist : Arrays.asList(NEAREST, FURTHEST, AVERAGE)) { + System.gc(); - int maxK = Arrays.stream(ks).max().getAsInt(); + mhTree.getNodes().forEach(node -> node.objectToNodeDistance = dist); - List<KNNQueryOperation> kNNOperations = objects - .parallelStream() - .map(object -> new KNNQueryOperation(object, maxK)) - .collect(Collectors.toList()); + System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)"); - class Pair { - public List<RankedAbstractObject> kNNObjects; - public String id; - } + int maxK = Arrays.stream(ks).max().getAsInt(); - Map<String, List<RankedAbstractObject>> kNNResults = kNNOperations - .parallelStream() - .map(op -> { - mTree.kNN(op); - Pair pair = new Pair(); - pair.id = op.getQueryObject().getLocatorURI(); - pair.kNNObjects = new ArrayList<>(maxK); - for (RankedAbstractObject o : op) - pair.kNNObjects.add(o); - return pair; - }) - .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects)); - - for (int k : ks) { - List<ApproxKNNQueryOperation> approxOperations = objects + List<KNNQueryOperation> kNNOperations = objects .parallelStream() - .map(object -> new ApproxKNNQueryOperation(object, k, 0, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE)) + .map(object -> new KNNQueryOperation(object, maxK)) .collect(Collectors.toList()); - approxOperations - .parallelStream() - .forEach(op -> op.suppData = new SearchState(mTree, op)); - double minimalRecall = 0; - int percentage = 0; - int percentageStep = 5; + class Pair { + public List<RankedAbstractObject> kNNObjects; + public String id; + } - while (minimalRecall != 1.0) { + Map<String, List<RankedAbstractObject>> kNNResults = kNNOperations + .parallelStream() + .map(op -> { + mhTree.kNN(op); + Pair pair = new Pair(); + pair.id = op.getQueryObject().getLocatorURI(); + pair.kNNObjects = new ArrayList<>(maxK); + for (RankedAbstractObject o : op) + pair.kNNObjects.add(o); + return pair; + }) + .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects)); + + for (int k : ks) { + List<ApproxKNNQueryOperation> approxOperations = objects + .parallelStream() + .map(object -> new ApproxKNNQueryOperation(object, k, 0, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE)) + .collect(Collectors.toList()); approxOperations .parallelStream() - .filter(op -> ((SearchState) op.suppData).recall != 1d) - .forEach(op -> { - SearchState searchState = (SearchState) op.suppData; - mTree.approxKNN(op); - searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults); - searchState.approxState.limit += Math.round((float) objects.size() * (float) percentageStep / 100f); - }); - - Stats recallStats = new Stats( - approxOperations - .stream() - .map(op -> ((SearchState) op.suppData).recall) - .collect(Collectors.toList()) - ); - - System.out.println(String.join(",", - String.valueOf(config.leafCapacity), - String.valueOf(config.nodeDegree), - String.valueOf(config.objectToNodeDistance), - String.valueOf(k), - String.valueOf(percentage), - String.format("%.2f,%.2f,%.2f,%.2f", - recallStats.getMin(), - recallStats.getAverage(), - recallStats.getMedian(), - recallStats.getMax()))); - - minimalRecall = recallStats.getMin(); - percentage += percentageStep; + .forEach(op -> op.suppData = new SearchState(mhTree, op)); + + double minimalRecall = 0; + int percentage = 0; + int percentageStep = 5; + + while (minimalRecall != 1.0) { + approxOperations + .parallelStream() + .filter(op -> ((SearchState) op.suppData).recall != 1d) + .forEach(op -> { + SearchState searchState = (SearchState) op.suppData; + mhTree.approxKNN(op); + searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults); + searchState.approximateState.limit += Math.round((float) objects.size() * (float) percentageStep / 100f); + }); + + Stats recallStats = new Stats( + approxOperations + .stream() + .map(op -> ((SearchState) op.suppData).recall) + .collect(Collectors.toList()) + ); + + System.out.println(String.join(",", + String.valueOf(config.leafCapacity), + String.valueOf(config.nodeDegree), + String.valueOf(dist), + String.valueOf(k), + String.valueOf(percentage), + String.format("%.2f,%.2f,%.2f,%.2f", + recallStats.getMin(), + recallStats.getAverage(), + recallStats.getMedian(), + recallStats.getMax()))); + + minimalRecall = recallStats.getMin(); + percentage += percentageStep; + } } } } @@ -214,7 +224,6 @@ public class RunBenchmark { } } - private static List<LocalAbstractObject> loadDataset(String path) throws IOException { return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path)); }