diff --git a/src/mhtree/benchmarking/PerformanceMeasures.java b/src/mhtree/benchmarking/PerformanceMeasures.java index dee978afb2663777e0f7b16954dfaf8e30a87a92..8d61f34222b7bc3d1feb578fba10932b72e16373 100644 --- a/src/mhtree/benchmarking/PerformanceMeasures.java +++ b/src/mhtree/benchmarking/PerformanceMeasures.java @@ -4,7 +4,6 @@ import messif.algorithms.Algorithm; import messif.algorithms.AlgorithmMethodException; import messif.objects.util.DistanceRankedObject; import messif.objects.util.RankedAbstractObject; -import messif.operations.query.ApproxKNNQueryOperation; import messif.operations.query.KNNQueryOperation; import java.util.ArrayList; @@ -16,7 +15,7 @@ import java.util.stream.Collectors; public class PerformanceMeasures { - public static double measureErrorOnThePosition(ApproxKNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException { + public static double measureErrorOnThePosition(KNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException { KNNQueryOperation rankedObjects = new KNNQueryOperation(approxKNNQueryOperation.getQueryObject(), tree.getObjectCount()); tree.executeOperation(rankedObjects); @@ -40,7 +39,7 @@ public class PerformanceMeasures { } // comparing done based on distances, counts how many of the same distances of KNNQueryOperation were presents in the answer of ApproxKNNQueryOperation - public static double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException { + public static double measureRecall(KNNQueryOperation approxKNNQueryOperation, Algorithm tree) throws NoSuchMethodException, AlgorithmMethodException { if (approxKNNQueryOperation.getAnswerCount() == 0) return 0d; KNNQueryOperation knnQueryOperation = new KNNQueryOperation(approxKNNQueryOperation.getQueryObject(), approxKNNQueryOperation.getK()); @@ -73,10 +72,11 @@ public class PerformanceMeasures { return trueKNNFoundCount / (double) knnQueryOperation.getAnswerCount(); } - public static double measureRecall(ApproxKNNQueryOperation approxKNNQueryOperation, Map<String, List<RankedAbstractObject>> trueKNN) { + public static double measureRecall(KNNQueryOperation approxKNNQueryOperation, Map<String, List<RankedAbstractObject>> trueKNN) { if (approxKNNQueryOperation.getAnswerCount() == 0) return 0d; - List<RankedAbstractObject> kNNObjects = trueKNN.get(approxKNNQueryOperation.getQueryObject().getLocatorURI()).subList(0, approxKNNQueryOperation.getK()); + List<RankedAbstractObject> kNNObjects = trueKNN.get(approxKNNQueryOperation.getQueryObject().getLocatorURI()) + .subList(0, approxKNNQueryOperation.getK()); Map<Float, Long> frequencyMap = kNNObjects .stream() diff --git a/src/mhtree/benchmarking/RunBenchmark.java b/src/mhtree/benchmarking/RunBenchmark.java index 9f1bfa9092ad8605995346637a15c178c3a72752..72fe19179de5155dd635a1552f452f02d664429b 100644 --- a/src/mhtree/benchmarking/RunBenchmark.java +++ b/src/mhtree/benchmarking/RunBenchmark.java @@ -4,7 +4,6 @@ import cz.muni.fi.disa.similarityoperators.cover.AbstractRepresentation; import messif.algorithms.AlgorithmMethodException; import messif.buckets.BucketStorageException; import messif.objects.LocalAbstractObject; -import messif.objects.impl.ObjectFloatVectorNeuralNetworkL2; import messif.objects.util.AbstractObjectList; import messif.objects.util.RankedAbstractObject; import messif.objects.util.StreamGenericAbstractObjectIterator; @@ -24,19 +23,19 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; -import java.util.logging.Level; -import java.util.logging.Logger; import java.util.stream.Collectors; import messif.algorithms.Algorithm; +import messif.objects.keys.AbstractObjectKey; import static mhtree.ObjectToNodeDistance.AVERAGE; import static mhtree.ObjectToNodeDistance.FURTHEST; import static mhtree.ObjectToNodeDistance.NEAREST; public class RunBenchmark { - public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException { - if (args.length != 7 && args.length != 10) { + public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException, ClassNotFoundException { + if (args.length != 8 && args.length != 11) { throw new IllegalArgumentException("Unexpected number of params"); } @@ -45,18 +44,21 @@ public class RunBenchmark { default -> false; }; + // e.g. messif.objects.impl.ObjectFloatVectorNeuralNetworkL2 + Class<? extends LocalAbstractObject> objClass = (Class<? extends LocalAbstractObject>) Class.forName(args[1]); + //Statistics.enableGlobally(); Statistics.disableGlobally(); AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = 16; - List<LocalAbstractObject> objects = loadDataset(args[1]); - int leafCapacity = Integer.parseInt(args[2]); - int nodeDegree = Integer.parseInt(args[3]); - List<LocalAbstractObject> queries = loadDataset(args[4]); + List<LocalAbstractObject> objects = loadDataset(args[2], objClass); + int leafCapacity = Integer.parseInt(args[3]); + int nodeDegree = Integer.parseInt(args[4]); + List<LocalAbstractObject> queries = loadDataset(args[5], objClass); - InsertType insertType = args[5].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY; + InsertType insertType = args[6].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY; - ObjectToNodeDistance objectToNodeDistance = switch (args[6]) { + ObjectToNodeDistance objectToNodeDistance = switch (args[7]) { case "FURTHEST" -> FURTHEST; case "AVERAGE" -> AVERAGE; default -> NEAREST; @@ -75,17 +77,17 @@ public class RunBenchmark { Arrays.asList(NEAREST), //Arrays.asList(NEAREST, FURTHEST, AVERAGE), queries, ks); - } else if (args.length == 7) { // Original M-tree + } else if (args.length == 8) { // Original M-tree percentageToRecallMTree(cfg, objects, queries, ks, 0, 0, null); } else { // Pivoting M-tree - List<LocalAbstractObject> pivots = loadDataset(args[9]); + List<LocalAbstractObject> pivots = loadDataset(args[10], objClass); percentageToRecallMTree(cfg, objects, queries, ks, - Integer.parseInt(args[7]), Integer.parseInt(args[8]), pivots); + Integer.parseInt(args[8]), Integer.parseInt(args[9]), pivots); } } @@ -202,8 +204,10 @@ public class RunBenchmark { mTree.insert(opIns); long buildingTime = System.currentTimeMillis() - buildingStartTimeStamp; + mTree.markObjectsWithBucketIds(); mTree.printStatistics(); - System.out.println("Fat factor: " + mTree.getFatFactor()); + //mTree.checkConsistency(); + //System.out.println("Fat factor: " + mTree.getFatFactor()); System.out.println("Building time: " + buildingTime + " msec"); System.out.println("kNN queries will be executed for k=" + Arrays.toString(ks)); @@ -257,16 +261,16 @@ public class RunBenchmark { "", String.valueOf(k), String.valueOf(percentage), - String.format("%.2f,%.2f,%.2f,%.2f", + String.format(Locale.ENGLISH, "%.2f,%.2f,%.2f,%.2f", recallStats.getMin(), recallStats.getAverage(), recallStats.getMedian(), recallStats.getMax()), - String.format("%.2f", timeStats.getAverage()))); + String.format(Locale.ENGLISH, "%.2f", timeStats.getAverage()))); minimalRecall = recallStats.getMin(); - if (minimalRecall == 1.0) - break; + //if (minimalRecall == 1.0) + // break; } } } @@ -285,17 +289,21 @@ public class RunBenchmark { alg.executeOperation(op); } catch (AlgorithmMethodException | NoSuchMethodException ex) { } KnnResultPair pair = new KnnResultPair(); + if (!AbstractObjectKey.class.equals(op.getQueryObject().getObjectKey().getClass())) + System.out.println("ERROR: Ground truth query has a non-expected abstract key! " + op.getQueryObject().getObjectKey()); pair.id = op.getQueryObject().getLocatorURI(); pair.kNNObjects = new ArrayList<>(maxK); for (RankedAbstractObject o : op) pair.kNNObjects.add(o); + if (pair.kNNObjects.size() != maxK) + System.out.println("ERROR: Ground truth for query " + pair.id + " constains " + pair.kNNObjects.size() + " objects, Expected " + maxK); return pair; }) .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects)); return kNNResults; } - private static List<LocalAbstractObject> loadDataset(String path) throws IOException { - return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path)); + private static List<LocalAbstractObject> loadDataset(String path, Class<? extends LocalAbstractObject> objCls) throws IOException { + return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(objCls, path)); } } diff --git a/src/mhtree/benchmarking/Stats.java b/src/mhtree/benchmarking/Stats.java index f9d793fe2d8d9d4c83df648123bd707f2e6ad748..8b760a4300851e45bae97517de685ca23a42e5f0 100644 --- a/src/mhtree/benchmarking/Stats.java +++ b/src/mhtree/benchmarking/Stats.java @@ -6,7 +6,7 @@ import java.util.List; public class Stats { List<Double> sortedData; - Stats(List<Double> data) { + public Stats(List<Double> data) { sortedData = data; Collections.sort(sortedData); }