From b65bdc266ddb6b6c079b7bc7fd292cd0bd823961 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Proch=C3=A1zka?= <david@prochazka.dev> Date: Thu, 8 Apr 2021 15:25:09 +0200 Subject: [PATCH] ADD: percentage to recall measurement --- .../benchmarking/PerformanceMeasures.java | 40 +---- src/mhtree/benchmarking/RunBenchmark.java | 137 +++++------------- 2 files changed, 44 insertions(+), 133 deletions(-) diff --git a/src/mhtree/benchmarking/PerformanceMeasures.java b/src/mhtree/benchmarking/PerformanceMeasures.java index 3e8d97b..767170d 100644 --- a/src/mhtree/benchmarking/PerformanceMeasures.java +++ b/src/mhtree/benchmarking/PerformanceMeasures.java @@ -46,15 +46,16 @@ public class PerformanceMeasures { KNNQueryOperation knnQueryOperation = new KNNQueryOperation(approxKNNQueryOperation.getQueryObject(), approxKNNQueryOperation.getK()); tree.executeOperation(knnQueryOperation); - List<RankedAbstractObject> objects = new ArrayList<>(knnQueryOperation.getAnswerCount()); + List<RankedAbstractObject> kNNObjects = new ArrayList<>(knnQueryOperation.getAnswerCount()); for (RankedAbstractObject object : knnQueryOperation) - objects.add(object); + kNNObjects.add(object); - Map<Float, Long> frequencyMap = objects.parallelStream() + Map<Float, Long> frequencyMap = kNNObjects + .stream() .map(DistanceRankedObject::getDistance) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); - long trueCount = 0; + long trueKNNFoundCount = 0; for (RankedAbstractObject approxObject : approxKNNQueryOperation) { float distance = approxObject.getDistance(); @@ -65,37 +66,10 @@ public class PerformanceMeasures { } else { frequencyMap.replace(distance, count - 1); } - trueCount++; + trueKNNFoundCount++; } } - return trueCount / (double) knnQueryOperation.getAnswerCount(); - } - - // comparing done based on distances, counts how many of the same distances of KNNQueryOperation were presents in the answer of ApproxKNNQueryOperation - public static double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation, HashMap<String, List<RankedAbstractObject>> kNNs) throws NoSuchMethodException, AlgorithmMethodException { - if (approxKNNQueryOperation.getAnswerCount() == 0) return 0d; - - Map<Float, Long> frequencyMap = kNNs.get(approxKNNQueryOperation.getQueryObject().getLocatorURI()).subList(0, approxKNNQueryOperation.getK()) - .parallelStream() - .map(DistanceRankedObject::getDistance) - .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); - - long trueCount = 0; - - for (RankedAbstractObject approxObject : approxKNNQueryOperation) { - float distance = approxObject.getDistance(); - if (frequencyMap.containsKey(distance)) { - long count = frequencyMap.get(distance); - if (count == 1) { - frequencyMap.remove(distance); - } else { - frequencyMap.replace(distance, count - 1); - } - trueCount++; - } - } - - return trueCount / (double) approxKNNQueryOperation.getK(); + return trueKNNFoundCount / (double) knnQueryOperation.getAnswerCount(); } } diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java index 1c59f28..2ed5262 100644 --- a/src/mhtree/benchmarking/RunBenchmark.java +++ b/src/mhtree/benchmarking/RunBenchmark.java @@ -6,53 +6,28 @@ import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; import messif.objects.util.AbstractObjectList; -import messif.objects.util.RankedAbstractObject; import messif.objects.util.StreamGenericAbstractObjectIterator; import messif.operations.Approximate; import messif.operations.query.ApproxKNNQueryOperation; -import messif.operations.query.KNNQueryOperation; +import messif.statistics.Statistics; import mhtree.InsertType; import mhtree.MHTree; import mhtree.ObjectToNodeDistance; -import mhtree.SearchState; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.stream.Collectors; public class RunBenchmark { - private static final String dataset = "./data/middle_c.txt"; - - public static void main(String[] args) throws IOException, NoSuchMethodException, AlgorithmMethodException, BucketStorageException, ClassNotFoundException { - if (!(args.length == 0 || args.length == 5)) { - throw new IllegalArgumentException("unexpected number of params"); + public static void main(String[] args) throws IOException, NoSuchMethodException, AlgorithmMethodException, BucketStorageException, InstantiationException { + if (args.length != 5) { + throw new IllegalArgumentException("Unexpected number of params"); } + Statistics.enableGlobally(); AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = 16; - List<LocalAbstractObject> objects; - - if (args.length != 5) { - objects = loadDataset(RunBenchmark.dataset); - - DCtoRecall( - new MHTreeConfig( - 10, - 2, - InsertType.GREEDY, - ObjectToNodeDistance.NEAREST - ), - objects, - new int[]{1, 50, 100} - ); - - return; - } - - objects = loadDataset(args[0]); + List<LocalAbstractObject> objects = loadDataset(args[0]); int leafCapacity = Integer.parseInt(args[1]); int nodeDegree = Integer.parseInt(args[2]); InsertType insertType = args[3].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY; @@ -70,90 +45,52 @@ public class RunBenchmark { break; } -// Collections.shuffle(objects); - - DCtoRecall(new MHTreeConfig( + percentageToRecall(new MHTreeConfig( leafCapacity, nodeDegree, insertType, objectToNodeDistance ), objects, - new int[]{1, 50} + new int[]{1, 10, 25, 50, 100} ); } - private static void DCtoRecall(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, BucketStorageException, RuntimeException { - MHTree mhTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree) - .insertType(config.insertType) + private static void percentageToRecall(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, BucketStorageException, RuntimeException { + int numberOfObjects = objects.size(); + + MHTree mTree = new MHTree.Builder(objects, config.leafCapacity, config.nodeDegree) .objectToNodeDistance(config.objectToNodeDistance) .build(); - mhTree.printStatistics(); - - int maxK = Arrays.stream(ks) - .max() - .orElseThrow(() -> new RuntimeException("bad max k")); - - // Precomputed NNs for each object - HashMap<String, List<RankedAbstractObject>> URIToKNNs = new HashMap<>(objects.size()); - - for (LocalAbstractObject object : objects) { - KNNQueryOperation knnQueryOperation = new KNNQueryOperation(object, maxK); - mhTree.executeOperation(knnQueryOperation); - - List<RankedAbstractObject> rankedKNNObjects = new ArrayList<>(knnQueryOperation.getAnswerCount()); - for (RankedAbstractObject knnObj : knnQueryOperation) { - rankedKNNObjects.add(knnObj); - } - - URIToKNNs.put(object.getLocatorURI(), rankedKNNObjects); - } + mTree.printStatistics(); - System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,distanceComputations,recall (min),recall (avg),recall (med),recall (max)"); + System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)"); - double minRecall = 0; - long distanceComputations = 0; + double minimalRecall = 0; + int percentage = 0; + int percentageStep = 5; for (int k : ks) { - List<ApproxKNNQueryOperation> operations = objects - .stream() - .map(object -> new ApproxKNNQueryOperation( - object, - k, - 1000, - Approximate.LocalSearchType.ABS_DC_COUNT, - LocalAbstractObject.UNKNOWN_DISTANCE)) - .collect(Collectors.toList()); - - operations.forEach(operation -> operation.suppData = new SearchState(null, null, false)); - - System.out.println( - String.join(",", - String.valueOf(config.leafCapacity), - String.valueOf(config.nodeDegree), - String.valueOf(config.objectToNodeDistance), - String.valueOf(k), - String.valueOf(distanceComputations), - String.format("%.2f,%.2f,%.2f,%.2f", 0f, 0f, 0f, 0f) - ) - ); - - distanceComputations += 1000; - - List<Double> recalls = new ArrayList<>(operations.size()); - for (int i = 0; i < operations.size(); i++) { - recalls.add(0d); + List<Double> recalls = new ArrayList<>(numberOfObjects); + for (int i = 0; i < numberOfObjects; i++) { + recalls.add(0.0); } - while (minRecall != 1) { - for (int i = 0; i < operations.size(); i++) { - if (recalls.get(i) != 1d) { - ApproxKNNQueryOperation operation = operations.get(i); + while (minimalRecall != 1.0) { + for (int i = 0; i < numberOfObjects; i++) { + if (recalls.get(i) != 1.0) { + ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation( + objects.get(i), + k, + percentage, + Approximate.LocalSearchType.PERCENTAGE, + LocalAbstractObject.UNKNOWN_DISTANCE + ); - mhTree.executeOperation(operation); + mTree.executeOperation(operation); - recalls.set(i, PerformanceMeasures.measureRecall(operation, URIToKNNs)); + recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree)); } } @@ -164,19 +101,19 @@ public class RunBenchmark { String.valueOf(config.nodeDegree), String.valueOf(config.objectToNodeDistance), String.valueOf(k), - String.valueOf(distanceComputations), + String.valueOf(percentage), String.format("%.2f,%.2f,%.2f,%.2f", recallStats.getMin(), recallStats.getAverage(), recallStats.getMedian(), recallStats.getMax()))); - minRecall = recallStats.getMin(); - distanceComputations += 1000; + minimalRecall = recallStats.getMin(); + percentage += percentageStep; } - minRecall = 0; - distanceComputations = 0; + minimalRecall = 0; + percentage = 0; } } -- GitLab