Commit c0aa6a1b authored by Vlastislav Dohnal

Minor refactoring of implementation and javadoc update

* RunBenchmark improved by calling Fat Factor and making M-tree code parallel.
parent 7a84b583
InsertType.java
@@ -5,18 +5,21 @@ package mhtree;
  */
 public enum InsertType {
     /**
-     * When the inserted object is not covered by a node, all objects under such node are retrieved,
+     * When the inserted object is not covered by a node, all objects under such node are retrieved (recursively down to the buckets),
      * and a new hull is built, replacing the current one.
      */
     GREEDY,
     /**
      * When the inserted object is not covered by node, we iterate over hull objects beginning with the nearest one.
-     * We try to replace the hull object by removing it from the hull and replacing it with the inserted object.
+     * We try to replace an existing hull object with the inserted object.
      * If the removed hull object is covered by the new hull, we are done.
      * If no such hull object is found, the inserted object is simply added as a new hull object.
      */
     INCREMENTAL,
+    /** Take current hull objects and the newly inserted object and compute a hull out of them.
+     * Vlasta: Is this identical to INCREMENTAL?
+     */
     ADD_HULL_OBJECT,
 }
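For intuition only, here is a deliberately tiny, self-contained toy that contrasts the two main strategies documented above. It is not MH-tree code: the "hull" is just a 1-D interval [min, max], and the class and method names are invented for this illustration. What it shows is the cost difference the javadoc describes: GREEDY re-reads every object stored under the node to rebuild the hull, while INCREMENTAL only touches the hull objects (here, the two interval endpoints).

import java.util.ArrayList;
import java.util.List;

class ToyHullNode {
    final List<Double> stored = new ArrayList<>(); // all objects under the node (the "bucket")
    double min = Double.POSITIVE_INFINITY;         // hull object #1
    double max = Double.NEGATIVE_INFINITY;         // hull object #2

    boolean covers(double o) {
        return o >= min && o <= max;
    }

    // GREEDY: when o is not covered, retrieve ALL stored objects and rebuild the hull from scratch.
    void insertGreedy(double o) {
        stored.add(o);
        if (!covers(o)) {
            min = Double.POSITIVE_INFINITY;
            max = Double.NEGATIVE_INFINITY;
            for (double x : stored) {              // full pass over the node's objects
                min = Math.min(min, x);
                max = Math.max(max, x);
            }
        }
    }

    // INCREMENTAL: when o is not covered, swap the nearest hull object for o;
    // in this 1-D toy the swapped-out endpoint is always covered by the widened interval.
    void insertIncremental(double o) {
        stored.add(o);
        if (!covers(o)) {
            if (Math.abs(o - min) <= Math.abs(o - max))
                min = o;                           // replaces the nearer hull object
            else
                max = o;
        }
    }

    public static void main(String[] args) {
        ToyHullNode node = new ToyHullNode();
        for (double x : new double[]{3, 5, 4}) node.insertGreedy(x);
        node.insertIncremental(7);                 // extends the hull without re-reading `stored`
        System.out.println("[" + node.min + ", " + node.max + "]");  // prints [3.0, 7.0]
    }
}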
Node.java
@@ -7,6 +7,7 @@ import messif.objects.LocalAbstractObject;
 
 import java.io.Serializable;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
@@ -36,7 +37,8 @@ public abstract class Node implements Serializable {
         this.objectToNodeDistance = objectToNodeDistance;
     }
 
-    protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType, ObjectToNodeDistance objectToNodeDistance, MergingMethod mergingMethod) {
+    protected static InternalNode createParent(List<Node> nodes, PrecomputedDistances distances, InsertType insertType,
+                                               ObjectToNodeDistance objectToNodeDistance, MergingMethod mergingMethod) {
         List<LocalAbstractObject> objects = nodes
                 .stream()
                 .map(mergingMethod::getObjects)
@@ -169,7 +171,8 @@ public abstract class Node implements Serializable {
     }
 
     private void insertIncremental(LocalAbstractObject object) {
-        hull.addHullObject(object);
+        //hull.addHullObject(object);
+        hull.setHullObjects(Collections.singletonList(object)); // singletonList returns a List (Collections.singleton would return a Set)
     }
 
     private void insertHullRebuild(LocalAbstractObject object) {
ObjectToNodeDistance.java
@@ -52,7 +52,6 @@ public enum ObjectToNodeDistance {
                     .min()
                     .orElse(Double.MAX_VALUE);
         }
-
     };
 
     /**
ObjectToNodeDistanceRank.java
@@ -45,4 +45,9 @@ public class ObjectToNodeDistanceRank implements Comparable<ObjectToNodeDistance
     public Node getNode() {
         return node;
     }
+
+    public double getDistance() {
+        return distance;
+    }
 }
New file: KnnResultPair.java (package mhtree.benchmarking)

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package mhtree.benchmarking;

import java.util.List;

import messif.objects.util.RankedAbstractObject;

/** Ground-truth entry: the exact kNN answer of one query, keyed by the query's locator URI. */
public class KnnResultPair {
    public List<RankedAbstractObject> kNNObjects;
    public String id;
}
RunBenchmark.java
@@ -25,7 +25,10 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
 import java.util.stream.Collectors;
+import messif.algorithms.Algorithm;
 
 import static mhtree.ObjectToNodeDistance.AVERAGE;
 import static mhtree.ObjectToNodeDistance.FURTHEST;
@@ -33,84 +36,83 @@ import static mhtree.ObjectToNodeDistance.NEAREST;
 public class RunBenchmark {
 
     public static void main(String[] args) throws IOException, AlgorithmMethodException, InstantiationException, NoSuchMethodException, BucketStorageException {
-        if (args.length != 5) {
+        if (args.length != 7) {
             throw new IllegalArgumentException("Unexpected number of params");
         }
 
-        Statistics.enableGlobally();
+        boolean isMHtree = switch (args[0]) {
+            case "MHULL-TREE" -> true;
+            default -> false;
+        };
+
+        //Statistics.enableGlobally();
+        Statistics.disableGlobally();
         AbstractRepresentation.PrecomputedDistances.COMPUTATION_THREADS = 16;
 
-        List<LocalAbstractObject> objects = loadDataset(args[0]);
-        int leafCapacity = Integer.parseInt(args[1]);
-        int nodeDegree = Integer.parseInt(args[2]);
-        InsertType insertType = args[3].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY;
-        ObjectToNodeDistance objectToNodeDistance;
-
-        switch (args[4]) {
-            case "FURTHEST":
-                objectToNodeDistance = FURTHEST;
-                break;
-            case "AVERAGE":
-                objectToNodeDistance = AVERAGE;
-                break;
-            default:
-                objectToNodeDistance = NEAREST;
-                break;
-        }
-
-        percentageToRecallMHTree(new MHTreeConfig(
-                        leafCapacity,
-                        nodeDegree,
-                        insertType,
-                        objectToNodeDistance
-                ),
-                objects,
-                new int[]{1, 50}
-        );
+        List<LocalAbstractObject> objects = loadDataset(args[1]);
+        int leafCapacity = Integer.parseInt(args[2]);
+        int nodeDegree = Integer.parseInt(args[3]);
+        List<LocalAbstractObject> queries = loadDataset(args[4]);
+        InsertType insertType = args[5].equals("INCREMENTAL") ? InsertType.INCREMENTAL : InsertType.GREEDY;
+
+        ObjectToNodeDistance objectToNodeDistance = switch (args[6]) {
+            case "FURTHEST" -> FURTHEST;
+            case "AVERAGE" -> AVERAGE;
+            default -> NEAREST;
+        };
+
+        final MHTreeConfig cfg = new MHTreeConfig(
+                leafCapacity,
+                nodeDegree,
+                insertType,
+                objectToNodeDistance
+        );
+        final int[] ks = new int[]{1, 3, 5, 10, 20, 50, 100};
+
+        if (isMHtree) {
+            percentageToRecallMHTree(cfg,
+                    objects,
+                    Arrays.asList(NEAREST),
+                    //Arrays.asList(NEAREST, FURTHEST, AVERAGE),
+                    queries, ks);
+        } else {
+            percentageToRecallMTree(cfg,
+                    objects,
+                    queries, ks);
+        }
     }
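With this change the benchmark takes seven positional arguments instead of five: index type ("MHULL-TREE" selects the MH-tree, anything else falls back to the M-tree), dataset file, leaf capacity, node degree, query file, insert type (INCREMENTAL or GREEDY), and object-to-node distance (FURTHEST, AVERAGE, or NEAREST). A hypothetical invocation, where the class path, package, and file names are illustrative assumptions rather than part of this commit, might look like:

    java -cp mhtree.jar mhtree.benchmarking.RunBenchmark MHULL-TREE dataset.data 50 10 queries.data INCREMENTAL NEAREST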
 
-    private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws BucketStorageException, RuntimeException {
+    private static void percentageToRecallMHTree(MHTreeConfig config, List<LocalAbstractObject> objects,
+                                                 final List<ObjectToNodeDistance> distFuncs,
+                                                 List<LocalAbstractObject> queries,
+                                                 int[] ks) throws BucketStorageException, RuntimeException {
+        long buildingStartTimeStamp = System.currentTimeMillis();
         MHTree mhTree = new MHTree.MHTreeBuilder(objects, config.leafCapacity, config.nodeDegree)
                 .objectToNodeDistance(config.objectToNodeDistance)
                 .mergingMethod(MergingMethod.HULL_BASED_MERGE)
                 .build();
+        long buildingTime = System.currentTimeMillis() - buildingStartTimeStamp;
 
+        //OperationStatistics.getLocalThreadStatistics().printStatistics();
         mhTree.printStatistics();
+        //System.out.println("Fat factor: " + mhTree.getFatFactor());
+        System.out.println("Building time: " + buildingTime + " msec");
 
-        for (ObjectToNodeDistance dist : Arrays.asList(NEAREST, FURTHEST, AVERAGE)) {
+        System.out.println("kNN queries will be executed for k=" + Arrays.toString(ks));
+        for (ObjectToNodeDistance dist : distFuncs) {
            System.gc();
            mhTree.getNodes().forEach(node -> node.objectToNodeDistance = dist);
 
-            System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)");
+            System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max), time (msec)");
 
-            int maxK = Arrays.stream(ks).max().getAsInt();
-            List<KNNQueryOperation> kNNOperations = objects
-                    .parallelStream()
-                    .map(object -> new KNNQueryOperation(object, maxK))
-                    .collect(Collectors.toList());
-
-            class Pair {
-                public List<RankedAbstractObject> kNNObjects;
-                public String id;
-            }
-
-            Map<String, List<RankedAbstractObject>> kNNResults = kNNOperations
-                    .parallelStream()
-                    .map(op -> {
-                        mhTree.kNN(op);
-                        Pair pair = new Pair();
-                        pair.id = op.getQueryObject().getLocatorURI();
-                        pair.kNNObjects = new ArrayList<>(maxK);
-                        for (RankedAbstractObject o : op)
-                            pair.kNNObjects.add(o);
-                        return pair;
-                    })
-                    .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects));
+            Map<String, List<RankedAbstractObject>> kNNResults = prepareGroundTruth(ks, queries, mhTree);
 
            for (int k : ks) {
-                List<ApproxKNNQueryOperation> approxOperations = objects
+                List<ApproxKNNQueryOperation> approxOperations = queries
                        .parallelStream()
                        .map(object -> new ApproxKNNQueryOperation(object, k, 0, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE))
                        .collect(Collectors.toList());
@@ -119,18 +121,22 @@ public class RunBenchmark {
                        .forEach(op -> op.suppData = new SearchState(mhTree, op));
 
                double minimalRecall = 0;
-                int percentage = 0;
-                int percentageStep = 5;
-
-                while (minimalRecall != 1.0) {
+                for (int percentage = 0; percentage <= 100; percentage += 5) {
+                    final int approxLimit = Math.round((float) objects.size() * (float) percentage / 100f);
                    approxOperations
                            .parallelStream()
                            .filter(op -> ((SearchState) op.suppData).recall != 1d)
                            .forEach(op -> {
                                SearchState searchState = (SearchState) op.suppData;
-                                mhTree.approxKNN(op);
+                                try {
+                                    //mhTree.approxKNN(op);
+                                    mhTree.executeOperation(op);
+                                } catch (AlgorithmMethodException | NoSuchMethodException ex) {
+                                }
+                                searchState.time = op.getParameter("OperationTime", Long.class);
                                searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults);
-                                searchState.approximateState.limit += Math.round((float) objects.size() * (float) percentageStep / 100f);
+                                searchState.approximateState.limit = approxLimit;
+                                op.resetAnswer();
                            });
 
                    Stats recallStats = new Stats(
@@ -139,6 +145,12 @@ public class RunBenchmark {
                            .map(op -> ((SearchState) op.suppData).recall)
                            .collect(Collectors.toList())
                    );
+                    Stats timeStats = new Stats(
+                            approxOperations
+                                    .stream()
+                                    .map(op -> (double) ((SearchState) op.suppData).time)
+                                    .collect(Collectors.toList())
+                    );
 
                    System.out.println(String.join(",",
                            String.valueOf(config.leafCapacity),
@@ -150,80 +162,144 @@ public class RunBenchmark {
                            recallStats.getMin(),
                            recallStats.getAverage(),
                            recallStats.getMedian(),
-                            recallStats.getMax())));
+                            recallStats.getMax()),
+                            String.format("%.2f", timeStats.getAverage())));
 
                    minimalRecall = recallStats.getMin();
-                    percentage += percentageStep;
+                    if (minimalRecall == 1.0)
+                        break;
                }
            }
        }
    }
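A concrete instance of the limit computation above, using a purely illustrative dataset size not taken from this commit: for 1,000,000 database objects, Math.round((float) 1_000_000 * 5f / 100f) = 50,000, so each 5 % step raises the approximation limit by 50,000 candidate objects (0, 50,000, 100,000, ...), and the loop stops early as soon as even the worst query reaches recall 1.0.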
 
-    private static void percentageToRecallMTree(MHTreeConfig config, List<LocalAbstractObject> objects, int[] ks) throws NoSuchMethodException, AlgorithmMethodException, RuntimeException, InstantiationException {
-        int numberOfObjects = objects.size();
-
+    private static void percentageToRecallMTree(MHTreeConfig config, List<LocalAbstractObject> objects,
+                                                List<LocalAbstractObject> queries,
+                                                int[] ks) throws NoSuchMethodException, AlgorithmMethodException, RuntimeException, InstantiationException {
+        long buildingStartTimeStamp = System.currentTimeMillis();
         MTree mTree = new MTree(config.nodeDegree, config.leafCapacity);
 
        Collections.shuffle(objects);
-        BulkInsertOperation op = new BulkInsertOperation(objects);
-        mTree.insert(op);
+        BulkInsertOperation opIns = new BulkInsertOperation(objects);
+        mTree.insert(opIns);
+        long buildingTime = System.currentTimeMillis() - buildingStartTimeStamp;
 
        mTree.printStatistics();
+        //System.out.println("Fat factor: " + mTree.getFatFactor());
+        System.out.println("Building time: " + buildingTime + " msec");
 
-        System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max)");
+        System.out.println("kNN queries will be executed for k=" + Arrays.toString(ks));
+        System.out.println("leafCapacity,nodeDegree,objectToNodeDistance,k,percentage,recall (min),recall (avg),recall (med),recall (max), time (msec)");
 
-        double minimalRecall = 0;
-        int percentage = 0;
-        int percentageStep = 5;
+        Map<String, List<RankedAbstractObject>> kNNResults = prepareGroundTruth(ks, queries, mTree);
 
+        // int numberOfQueries = queries.size();
        for (int k : ks) {
-            List<Double> recalls = new ArrayList<>(numberOfObjects);
-            for (int i = 0; i < numberOfObjects; i++) {
-                recalls.add(0.0);
-            }
-
-            while (minimalRecall != 1.0) {
-                for (int i = 0; i < numberOfObjects; i++) {
-                    if (recalls.get(i) != 1.0) {
-                        ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation(
-                                objects.get(i),
-                                k,
-                                percentage,
-                                Approximate.LocalSearchType.PERCENTAGE,
-                                LocalAbstractObject.UNKNOWN_DISTANCE
-                        );
-
-                        mTree.executeOperation(operation);
-
-                        recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree));
-                    }
-                }
-
-                Stats recallStats = new Stats(new ArrayList<>(recalls));
+            double minimalRecall = 0;
+            for (int percentage = 0; percentage <= 100; percentage += 5) {
+                final int approxLimit = percentage;
+                List<ApproxKNNQueryOperation> approxOperations = queries
+                        .parallelStream()
+                        .map(object -> new ApproxKNNQueryOperation(object, k, approxLimit, Approximate.LocalSearchType.PERCENTAGE, LocalAbstractObject.UNKNOWN_DISTANCE))
+                        .collect(Collectors.toList());
+                approxOperations
+                        .parallelStream()
+                        .forEach(op -> op.suppData = new SearchInfoState());
+
+                approxOperations
+                        .parallelStream()
+                        .forEach(op -> {
+                            SearchInfoState searchState = (SearchInfoState) op.suppData;
+                            try {
+                                //mhTree.approxKNN(op);
+                                mTree.executeOperation(op);
+                            } catch (AlgorithmMethodException | NoSuchMethodException ex) {
+                            }
+                            searchState.time = op.getParameter("OperationTime", Long.class);
+                            searchState.recall = PerformanceMeasures.measureRecall(op, kNNResults);
+                            //searchState.approximateState.limit += Math.round((float) objects.size() * (float) percentageStep / 100f);
+                        });
+
+                Stats recallStats = new Stats(
+                        approxOperations
+                                .stream()
+                                .map(op -> ((SearchInfoState) op.suppData).recall)
+                                .collect(Collectors.toList())
+                );
+                Stats timeStats = new Stats(
+                        approxOperations
+                                .stream()
+                                .map(op -> (double) ((SearchInfoState) op.suppData).time)
+                                .collect(Collectors.toList())
+                );
+
+                // for (int i = 0; i < numberOfQueries; i++) {
+                //     if (recalls.get(i) != 1.0) {
+                //         ApproxKNNQueryOperation operation = new ApproxKNNQueryOperation(
+                //                 queries.get(i),
+                //                 k,
+                //                 percentage,
+                //                 Approximate.LocalSearchType.PERCENTAGE,
+                //                 LocalAbstractObject.UNKNOWN_DISTANCE
+                //         );
+                //
+                //         mTree.executeOperation(operation);
+                //
+                //         times.set(i, (double) operation.getParameter("OperationTime", Long.class));
+                //         recalls.set(i, PerformanceMeasures.measureRecall(operation, mTree));
+                //     }
+                // }
 
                System.out.println(String.join(",",
                        String.valueOf(config.leafCapacity),
                        String.valueOf(config.nodeDegree),
-                        String.valueOf(config.objectToNodeDistance),
+                        "",
                        String.valueOf(k),
                        String.valueOf(percentage),
                        String.format("%.2f,%.2f,%.2f,%.2f",
                                recallStats.getMin(),
                                recallStats.getAverage(),
                                recallStats.getMedian(),
-                                recallStats.getMax())));
+                                recallStats.getMax()),
+                        String.format("%.2f", timeStats.getAverage())));
 
                minimalRecall = recallStats.getMin();
-                percentage += percentageStep;
+                if (minimalRecall == 1.0)
+                    break;
            }
-
-            minimalRecall = 0;
-            percentage = 0;
        }
    }
 
+    private static Map<String, List<RankedAbstractObject>> prepareGroundTruth(int[] ks, List<LocalAbstractObject> queries, Algorithm alg) {
+        int maxK = Arrays.stream(ks).max().getAsInt();
+        List<KNNQueryOperation> kNNOperations = queries
+                .parallelStream()
+                .map(object -> new KNNQueryOperation(object, maxK))
+                .collect(Collectors.toList());
+
+        Map<String, List<RankedAbstractObject>> kNNResults = kNNOperations
+                .parallelStream()
+                .map((KNNQueryOperation op) -> {
+                    try {
+                        alg.executeOperation(op);
+                    } catch (AlgorithmMethodException | NoSuchMethodException ex) {
+                    }
+                    KnnResultPair pair = new KnnResultPair();
+                    pair.id = op.getQueryObject().getLocatorURI();
+                    pair.kNNObjects = new ArrayList<>(maxK);
+                    for (RankedAbstractObject o : op)
+                        pair.kNNObjects.add(o);
+                    return pair;
+                })
+                .collect(Collectors.toMap(entry -> entry.id, entry -> entry.kNNObjects));
+        return kNNResults;
+    }
+
    private static List<LocalAbstractObject> loadDataset(String path) throws IOException {
        return new AbstractObjectList<>(new StreamGenericAbstractObjectIterator<>(ObjectFloatVectorNeuralNetworkL2.class, path));
    }
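PerformanceMeasures.measureRecall is called above but is not part of this diff. The following is only a plausible, self-contained sketch of such a measure (the overlap of locator URIs between an approximate answer and the stored ground truth), written against the messif.objects.util.RankedAbstractObject type used in prepareGroundTruth; the class name, method name, and parameter shapes are invented for illustration and may differ from the real helper.

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import messif.objects.util.RankedAbstractObject;

/** Illustrative only: one possible recall measure over the ground-truth map built by prepareGroundTruth. */
public class RecallSketch {

    /**
     * @param queryLocator locator URI of the query object (the key used by prepareGroundTruth)
     * @param approxAnswer objects returned by the approximate kNN operation
     * @param groundTruth  exact kNN answers keyed by query locator URI
     * @return fraction of the exact answer that the approximate answer found (0 if no ground truth is stored)
     */
    public static double recall(String queryLocator,
                                Iterable<RankedAbstractObject> approxAnswer,
                                Map<String, List<RankedAbstractObject>> groundTruth) {
        List<RankedAbstractObject> exact = groundTruth.get(queryLocator);
        if (exact == null || exact.isEmpty())
            return 0d;

        // locators of the exact answer
        Set<String> exactLocators = new HashSet<>();
        for (RankedAbstractObject o : exact)
            exactLocators.add(o.getObject().getLocatorURI());

        // count approximate results that hit the exact answer
        int hits = 0;
        for (RankedAbstractObject o : approxAnswer)
            if (exactLocators.contains(o.getObject().getLocatorURI()))
                hits++;

        return (double) hits / exactLocators.size();
    }
}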
New file: SearchInfoState.java (package mhtree.benchmarking)

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package mhtree.benchmarking;

/** Per-query benchmark measurements: recall and operation time. */
public class SearchInfoState {

    public double recall;
    public long time; // in msec

    public SearchInfoState() {
        this.recall = 0d;
        this.time = 0;
    }
}
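The Stats class that RunBenchmark uses to aggregate these per-query values (min, average, median, max) is also outside this diff. Below is a minimal sketch of such an aggregator, assuming only what the calls above reveal: it is constructed from a List<Double> and exposes getMin, getAverage, getMedian, and getMax. The class name StatsSketch is deliberately different, since the real implementation may differ.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Illustrative aggregate over per-query values (recall or time), as consumed by RunBenchmark. */
public class StatsSketch {

    private final List<Double> sorted;

    public StatsSketch(List<Double> values) {
        this.sorted = new ArrayList<>(values);
        Collections.sort(this.sorted);   // sorted copy makes min/median/max trivial
    }

    public double getMin() {
        return sorted.isEmpty() ? Double.NaN : sorted.get(0);
    }

    public double getMax() {
        return sorted.isEmpty() ? Double.NaN : sorted.get(sorted.size() - 1);
    }

    public double getAverage() {
        return sorted.stream().mapToDouble(Double::doubleValue).average().orElse(Double.NaN);
    }

    public double getMedian() {
        if (sorted.isEmpty())
            return Double.NaN;
        int mid = sorted.size() / 2;
        // even count: mean of the two middle values
        return (sorted.size() % 2 == 1) ? sorted.get(mid) : (sorted.get(mid - 1) + sorted.get(mid)) / 2d;
    }
}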
SearchState.java
@@ -8,17 +8,15 @@ import mhtree.ObjectToNodeDistanceRank;
 
 import java.util.PriorityQueue;
 
-public class SearchState {
+public class SearchState extends SearchInfoState {
 
     public PriorityQueue<ObjectToNodeDistanceRank> queue;
     public LocalAbstractObject queryObject;
     public ApproximateState approximateState;
-    public double recall;
 
     public SearchState(MHTree tree, ApproxKNNQueryOperation operation) {
         this.queue = new PriorityQueue<>();
         this.queue.add(new ObjectToNodeDistanceRank(operation.getQueryObject(), tree.getRoot(), operation.getK()));
         this.queryObject = operation.getQueryObject();
         this.approximateState = ApproximateState.create(operation, tree);
-        this.recall = 0d;
     }
 }