Skip to content
Snippets Groups Projects
Commit 1e0617f8 authored by Vlastislav Dohnal's avatar Vlastislav Dohnal
Browse files

Test pipeline fix - missing object in GT for large "k"s.

parent c0a25fd8
No related branches found
No related tags found
No related merge requests found
......@@ -73,10 +73,16 @@ public class PerformanceMeasures {
}
public static double measureRecall(KNNQueryOperation approxKNNQueryOperation, Map<String, List<RankedAbstractObject>> trueKNN) {
if (approxKNNQueryOperation.getAnswerCount() == 0) return 0d;
List<RankedAbstractObject> kNNObjects = trueKNN.get(approxKNNQueryOperation.getQueryObject().getLocatorURI())
.subList(0, approxKNNQueryOperation.getK());
if (approxKNNQueryOperation.getAnswerCount() == 0)
return 0d;
int k = approxKNNQueryOperation.getK();
final List<RankedAbstractObject> gt = trueKNN.get(approxKNNQueryOperation.getQueryObject().getLocatorURI());
if (k > gt.size()) {
System.err.println("Ground truch contains just " + gt.size() + " objects but approx has " + k + " objects. Query: " + approxKNNQueryOperation.getQueryObject().getLocatorURI());
k = gt.size();
}
List<RankedAbstractObject> kNNObjects = gt.subList(0, k);
Map<Float, Long> frequencyMap = kNNObjects
.stream()
......
......@@ -22,18 +22,25 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import messif.algorithms.Algorithm;
import messif.objects.keys.AbstractObjectKey;
import messif.operations.AnswerType;
import messif.statistics.OperationStatistics;
import static mhtree.ObjectToNodeDistance.AVERAGE;
import static mhtree.ObjectToNodeDistance.FURTHEST;
import static mhtree.ObjectToNodeDistance.NEAREST;
public class RunBenchmark {
static final Logger log = Logger.getLogger("messif.algorithm");
public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException, ClassNotFoundException {
if (args.length != 8 && args.length != 11) {
throw new IllegalArgumentException("Unexpected number of params");
......@@ -47,7 +54,7 @@ public class RunBenchmark {
// e.g. messif.objects.impl.ObjectFloatVectorNeuralNetworkL2
Class<? extends LocalAbstractObject> objClass = (Class<? extends LocalAbstractObject>) Class.forName(args[1]);
//Statistics.enableGlobally();
// Statistics.enableGlobally();
Statistics.disableGlobally();
AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = 16;
......@@ -71,6 +78,7 @@ public class RunBenchmark {
objectToNodeDistance
);
final int[] ks = new int[]{1, 3, 5, 10, 20, 50, 100};
// final int[] ks = new int[]{3};
if (isMHtree) {
percentageToRecallMHTree(cfg,
objects,
......@@ -145,7 +153,8 @@ public class RunBenchmark {
}
searchState.time += op.getParameter("OperationTime", Long.class);
searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults);
// Vlasta: NO!!! Because we test it incrementally!!!! See the search state of operation! op.resetAnswer();
log.log(Level.INFO, "{0} processed: {1}; Recall: {2}", new Object[]{mhTree.getName(), op.toString(), searchState.recall});
// Vlasta: do NOT call op.resetAnswer() here. The query is evaluated incrementally, so the answer must be kept between runs; see the search state of the operation.
});
Stats recallStats = new Stats(
......@@ -198,9 +207,11 @@ public class RunBenchmark {
Math.max(pmTreeNPD, pmTreeNHR), pmTreePivots.iterator(), pmTreeNPD, pmTreeNHR);
Collections.shuffle(objects);
System.out.println("Shuffling objects done. First is now " + objects.get(0).getLocatorURI());
BulkInsertOperation opIns = new BulkInsertOperation(objects);
//mTree.setMaxSpanningTree(1);
mTree.insert(opIns);
long buildingTime = System.currentTimeMillis() - buildingStartTimeStamp;
......@@ -220,11 +231,13 @@ public class RunBenchmark {
// int numberOfQueries = queries.size();
for (int k : ks) {
double minimalRecall = 0;
// for (int percentage = 55; percentage <= 55; percentage += 1) {
// for (int percentage = 0; percentage <= 5; percentage += 1) {
for (int percentage = 0; percentage <= 100; percentage += 5) {
final int approxLimit = percentage;
List<ApproxKNNQueryOperation> approxOperations = queries
.parallelStream()
.map(object -> new ApproxKNNQueryOperation(object, k, approxLimit, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE))
.map(object -> new ApproxKNNQueryOperation(object, k, AnswerType.ORIGINAL_OBJECTS, approxLimit, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE))
.collect(Collectors.toList());
approxOperations
.parallelStream()
......@@ -240,6 +253,23 @@ public class RunBenchmark {
}
searchState.time = op.getParameter("OperationTime", Long.class);
searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults);
log.log(Level.INFO, "{0} processed: {1}; Recall: {2}", new Object[]{mTree.getName(), op.toString(), searchState.recall});
log.log(Level.INFO, "{0} processed: {1}; Answer: {2}", new Object[]{mTree.getName(), op.toString(), iterToString(op.getAnswer())});
OperationStatistics.getLocalThreadStatistics().printStatistics();
// if (searchState.recall != 1.0) {
// mTree.checkConsistency();
// try {
// mTree.storeToFile("mtree-bad.bin");
// } catch (IOException ex) {
// Logger.getLogger(RunBenchmark.class.getName()).log(Level.SEVERE, null, ex);
// }
// } else {
// try {
// mTree.storeToFile("mtree-ok.bin");
// } catch (IOException ex) {
// Logger.getLogger(RunBenchmark.class.getName()).log(Level.SEVERE, null, ex);
// }
// }
});
Stats recallStats = new Stats(
......@@ -274,13 +304,23 @@ public class RunBenchmark {
}
}
}
/**
 * Renders all remaining elements of an iterator as one string, separated by {@code ", "}.
 *
 * <p>Each element is appended via {@link StringBuilder#append(Object)}, so a {@code null}
 * element is rendered as {@code "null"}. The iterator is consumed by this call.
 *
 * @param it iterator whose remaining elements are rendered; must not be {@code null}
 * @return the joined elements, or an empty string when the iterator has no elements
 */
private static String iterToString(Iterator<?> it) {
    StringBuilder sb = new StringBuilder();
    // Track the position explicitly instead of testing sb.length(): an element whose
    // string form is empty would otherwise suppress the separator before the next one.
    boolean first = true;
    while (it.hasNext()) {
        if (!first) {
            sb.append(", ");
        }
        sb.append(it.next());
        first = false;
    }
    return sb.toString();
}
private static Map<String, List<RankedAbstractObject>> prepareGroundTruth(int[] ks, List<LocalAbstractObject> queries, Algorithm alg) {
int maxK = Arrays.stream(ks).max().getAsInt();
List<KNNQueryOperation> kNNOperations = queries
.parallelStream()
.map(object -> new KNNQueryOperation(object, maxK))
.map(object -> new KNNQueryOperation(object, maxK, AnswerType.ORIGINAL_OBJECTS))
.collect(Collectors.toList());
Map<String, List<RankedAbstractObject>> kNNResults = kNNOperations
.parallelStream()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment